From 21acda237744d4299e5bb449dce1ec8a1735f6de Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 4 May 2018 18:13:01 -0700 Subject: [PATCH 1/3] SPARK-24252: Add v2 data source mix-in for catalog support. --- .../catalog/v2/CaseInsensitiveStringMap.java | 107 ++++++++++ .../spark/sql/catalog/v2/CatalogProvider.java | 50 +++++ .../apache/spark/sql/catalog/v2/Catalogs.java | 109 +++++++++++ .../apache/spark/sql/catalog/v2/Table.java | 47 +++++ .../spark/sql/catalog/v2/TableCatalog.java | 137 +++++++++++++ .../spark/sql/catalog/v2/TableChange.java | 182 +++++++++++++++++ .../v2/CaseInsensitiveStringMapSuite.java | 48 +++++ .../sql/catalog/v2/CatalogLoadingSuite.java | 184 ++++++++++++++++++ .../org/apache/spark/sql/SparkSession.scala | 8 + 9 files changed, 872 insertions(+) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/CaseInsensitiveStringMap.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/CatalogProvider.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Catalogs.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Table.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableCatalog.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableChange.java create mode 100644 sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CaseInsensitiveStringMapSuite.java create mode 100644 sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CatalogLoadingSuite.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/CaseInsensitiveStringMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/CaseInsensitiveStringMap.java new file mode 100644 index 000000000000..a4ad1f6994f9 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/CaseInsensitiveStringMap.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Set; + +/** + * Case-insensitive map of string keys to string values. + *

+ * This is used to pass options to v2 implementations to ensure consistent case insensitivity. + *
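+ * For example, a sketch of the intended behavior, using only methods defined below (the key
+ * and value are illustrative):
+ *
+ *   CaseInsensitiveStringMap options = CaseInsensitiveStringMap.empty();
+ *   options.put("paTh", "/tmp/data");
+ *   options.get("path");   // returns "/tmp/data"
+ *   options.keySet();      // contains "path", converted to lower case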

+ * Methods that return keys in this map, like {@link #entrySet()} and {@link #keySet()}, return + * keys converted to lower case. + */ +public class CaseInsensitiveStringMap implements Map { + + public static CaseInsensitiveStringMap empty() { + return new CaseInsensitiveStringMap(); + } + + private final Map delegate; + + private CaseInsensitiveStringMap() { + this.delegate = new HashMap<>(); + } + + @Override + public int size() { + return delegate.size(); + } + + @Override + public boolean isEmpty() { + return delegate.isEmpty(); + } + + @Override + public boolean containsKey(Object key) { + return delegate.containsKey(key.toString().toLowerCase(Locale.ROOT)); + } + + @Override + public boolean containsValue(Object value) { + return delegate.containsValue(value); + } + + @Override + public String get(Object key) { + return delegate.get(key.toString().toLowerCase(Locale.ROOT)); + } + + @Override + public String put(String key, String value) { + return delegate.put(key.toLowerCase(Locale.ROOT), value); + } + + @Override + public String remove(Object key) { + return delegate.remove(key.toString().toLowerCase(Locale.ROOT)); + } + + @Override + public void putAll(Map m) { + for (Map.Entry entry : m.entrySet()) { + delegate.put(entry.getKey().toLowerCase(Locale.ROOT), entry.getValue()); + } + } + + @Override + public void clear() { + delegate.clear(); + } + + @Override + public Set keySet() { + return delegate.keySet(); + } + + @Override + public Collection values() { + return delegate.values(); + } + + @Override + public Set> entrySet() { + return delegate.entrySet(); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/CatalogProvider.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/CatalogProvider.java new file mode 100644 index 000000000000..03831b7aa915 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/CatalogProvider.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2; + +import org.apache.spark.sql.internal.SQLConf; + +/** + * A marker interface to provide a catalog implementation for Spark. + *

+ * Implementations can provide catalog functions by implementing additional interfaces, like + * {@link TableCatalog} to expose table operations. + *

+ * Catalog implementations must implement this marker interface to be loaded by + * {@link Catalogs#load(String, SQLConf)}. The loader will instantiate catalog classes using the + * required public no-arg constructor. After creating an instance, it will be configured by calling + * {@link #initialize(CaseInsensitiveStringMap)}. + *
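+ * A minimal implementation might look like the following sketch (the class name is
+ * illustrative; it only mirrors the lifecycle described above):
+ *
+ *   public class ExampleCatalog implements CatalogProvider {
+ *     private CaseInsensitiveStringMap options = null;
+ *
+ *     public ExampleCatalog() { }  // public no-arg constructor used by the loader
+ *
+ *     public void initialize(CaseInsensitiveStringMap options) {
+ *       this.options = options;    // called once, after instantiation
+ *     }
+ *   }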

+ * Catalog implementations are registered to a name by adding a configuration option to Spark:
+ * {@code spark.sql.catalog.catalog-name=com.example.YourCatalogClass}. All configuration
+ * properties in the Spark configuration that share the catalog name prefix,
+ * {@code spark.sql.catalog.catalog-name.(key)=(value)}, will be passed to the case-insensitive
+ * string map of options used for initialization, with the prefix removed. An additional property,
+ * {@code name}, is also added to the options and will contain the catalog's name; in this case,
+ * "catalog-name".
+ */
+public interface CatalogProvider {
+  /**
+   * Called to initialize configuration.
+   *

+ * This method is called once, just after the provider is instantiated. + * + * @param options a case-insensitive string map of configuration + */ + void initialize(CaseInsensitiveStringMap options); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Catalogs.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Catalogs.java new file mode 100644 index 000000000000..71ab9f528dbe --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Catalogs.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2; + +import org.apache.spark.SparkException; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.util.Utils; + +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static scala.collection.JavaConverters.mapAsJavaMapConverter; + +public class Catalogs { + private Catalogs() { + } + + /** + * Load and configure a catalog by name. + *
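+   * A sketch of expected usage (the catalog name "prod" is illustrative):
+   *
+   *   // requires spark.sql.catalog.prod to name a CatalogProvider implementation
+   *   CatalogProvider prod = Catalogs.load("prod", conf);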

+ * This loads, instantiates, and initializes the catalog provider for each call; it does not + * cache or reuse instances. + * + * @param name a String catalog name + * @param conf a SQLConf + * @return an initialized CatalogProvider + * @throws SparkException If the provider class cannot be found or instantiated + */ + public static CatalogProvider load(String name, SQLConf conf) throws SparkException { + String providerClassName = conf.getConfString("spark.sql.catalog." + name, null); + if (providerClassName == null) { + throw new SparkException(String.format( + "Catalog '%s' provider not found: spark.sql.catalog.%s is not defined", name, name)); + } + + ClassLoader loader = Utils.getContextOrSparkClassLoader(); + + try { + Class providerClass = loader.loadClass(providerClassName); + + if (!CatalogProvider.class.isAssignableFrom(providerClass)) { + throw new SparkException(String.format( + "Provider class for catalog '%s' does not implement CatalogProvider: %s", + name, providerClassName)); + } + + CatalogProvider provider = CatalogProvider.class.cast(providerClass.newInstance()); + + provider.initialize(catalogOptions(name, conf)); + + return provider; + + } catch (ClassNotFoundException e) { + throw new SparkException(String.format( + "Cannot find catalog provider class for catalog '%s': %s", name, providerClassName)); + + } catch (IllegalAccessException e) { + throw new SparkException(String.format( + "Failed to call public no-arg constructor for catalog '%s': %s", name, providerClassName), + e); + + } catch (InstantiationException e) { + throw new SparkException(String.format( + "Failed while instantiating provider for catalog '%s': %s", name, providerClassName), + e.getCause()); + } + } + + /** + * Extracts a named catalog's configuration from a SQLConf. + * + * @param name a catalog name + * @param conf a SQLConf + * @return a case insensitive string map of options starting with spark.sql.catalog.(name). + */ + private static CaseInsensitiveStringMap catalogOptions(String name, SQLConf conf) { + Map allConfs = mapAsJavaMapConverter(conf.getAllConfs()).asJava(); + Pattern prefix = Pattern.compile("^spark\\.sql\\.catalog\\." + name + "\\.(.+)"); + + CaseInsensitiveStringMap options = CaseInsensitiveStringMap.empty(); + for (Map.Entry entry : allConfs.entrySet()) { + Matcher matcher = prefix.matcher(entry.getKey()); + if (matcher.matches() && matcher.groupCount() > 0) { + options.put(matcher.group(1), entry.getValue()); + } + } + + // add name last to ensure it overwrites any conflicting options + options.put("name", name); + + return options; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Table.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Table.java new file mode 100644 index 000000000000..30a20f27b8c6 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Table.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2; + +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.types.StructType; + +import java.util.List; +import java.util.Map; + +/** + * Represents table metadata from a {@link TableCatalog} or other table sources. + */ +public interface Table { + /** + * Return the table properties. + * @return this table's map of string properties + */ + Map properties(); + + /** + * Return the table schema. + * @return this table's schema as a struct type + */ + StructType schema(); + + /** + * Return the table partitioning expressions. + * @return this table's partitioning expressions + */ + List partitionExpressions(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableCatalog.java new file mode 100644 index 000000000000..539beb0c39c5 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableCatalog.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2; + +import org.apache.spark.sql.catalyst.TableIdentifier; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.types.StructType; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public interface TableCatalog extends CatalogProvider { + /** + * Load table metadata by {@link TableIdentifier identifier} from the catalog. + * + * @param ident a table identifier + * @return the table's metadata + * @throws NoSuchTableException If the table doesn't exist. + */ + Table loadTable(TableIdentifier ident) throws NoSuchTableException; + + /** + * Test whether a table exists using an {@link TableIdentifier identifier} from the catalog. + * + * @param ident a table identifier + * @return true if the table exists, false otherwise + */ + default boolean tableExists(TableIdentifier ident) { + try { + return loadTable(ident) != null; + } catch (NoSuchTableException e) { + return false; + } + } + + /** + * Create a table in the catalog. 
+   *
+   * @param ident a table identifier
+   * @param schema the schema of the new table, as a struct type
+   * @return metadata for the new table
+   * @throws TableAlreadyExistsException If a table already exists for the identifier
+   */
+  default Table createTable(TableIdentifier ident,
+                            StructType schema) throws TableAlreadyExistsException {
+    return createTable(ident, schema, Collections.emptyList(), Collections.emptyMap());
+  }
+
+  /**
+   * Create a table in the catalog.
+   *
+   * @param ident a table identifier
+   * @param schema the schema of the new table, as a struct type
+   * @param properties a string map of table properties
+   * @return metadata for the new table
+   * @throws TableAlreadyExistsException If a table already exists for the identifier
+   */
+  default Table createTable(TableIdentifier ident,
+                            StructType schema,
+                            Map<String, String> properties) throws TableAlreadyExistsException {
+    return createTable(ident, schema, Collections.emptyList(), properties);
+  }
+
+  /**
+   * Create a table in the catalog.
+   *
+   * @param ident a table identifier
+   * @param schema the schema of the new table, as a struct type
+   * @param partitions a list of expressions to use for partitioning data in the table
+   * @param properties a string map of table properties
+   * @return metadata for the new table
+   * @throws TableAlreadyExistsException If a table already exists for the identifier
+   */
+  Table createTable(TableIdentifier ident,
+                    StructType schema,
+                    List<Expression> partitions,
+                    Map<String, String> properties) throws TableAlreadyExistsException;
+
+  /**
+   * Apply a list of {@link TableChange changes} to a table in the catalog.
+   *

+ * Implementations may reject the requested changes. If any change is rejected, none of the + * changes should be applied to the table. + * + * @param ident a table identifier + * @param changes a list of changes to apply to the table + * @return updated metadata for the table + * @throws NoSuchTableException If the table doesn't exist. + * @throws IllegalArgumentException If any change is rejected by the implementation. + */ + Table alterTable(TableIdentifier ident, + List changes) throws NoSuchTableException; + + /** + * Apply {@link TableChange changes} to a table in the catalog. + *

+ * Implementations may reject the requested changes. If any change is rejected, none of the + * changes should be applied to the table. + * + * @param ident a table identifier + * @param changes a list of changes to apply to the table + * @return updated metadata for the table + * @throws NoSuchTableException If the table doesn't exist. + * @throws IllegalArgumentException If any change is rejected by the implementation. + */ + default Table alterTable(TableIdentifier ident, + TableChange... changes) throws NoSuchTableException { + return alterTable(ident, Arrays.asList(changes)); + } + + /** + * Drop a table in the catalog. + * + * @param ident a table identifier + * @return true if a table was deleted, false if no table exists for the identifier + */ + boolean dropTable(TableIdentifier ident); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableChange.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableChange.java new file mode 100644 index 000000000000..3a8ba5e00b39 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableChange.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2; + +import org.apache.spark.sql.types.DataType; + +/** + * TableChange subclasses represent requested changes to a table. These are passed to + * {@link TableCatalog#alterTable}. For example, + *

+ *   import TableChange._
+ *   val catalog = source.asInstanceOf[TableSupport].catalog()
+ *   catalog.alterTable(ident,
+ *       addColumn("x", IntegerType),
+ *       renameColumn("a", "b"),
+ *       deleteColumn("c")
+ *     )
+ * 
+ */ +public interface TableChange { + + /** + * Create a TableChange for adding a top-level column to a table. + *

+ * Because "." may be interpreted as a field path separator or may be used in field names, it is + * not allowed in names passed to this method. To add to nested types or to add fields with + * names that contain ".", use {@link #addColumn(String, String, DataType)}. + * + * @param name the new top-level column name + * @param dataType the new column's data type + * @return a TableChange for the addition + */ + static TableChange addColumn(String name, DataType dataType) { + return new AddColumn(null, name, dataType); + } + + /** + * Create a TableChange for adding a nested column to a table. + *

+ * The parent name is used to find the parent struct type where the nested field will be added. + * If the parent name is null, the new column will be added to the root as a top-level column. + * If parent identifies a struct, a new column is added to that struct. If it identifies a list, + * the column is added to the list element struct, and if it identifies a map, the new column is + * added to the map's value struct. + *
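+   * For example, a sketch of adding a nested field (the column names are illustrative):
+   *
+   *   // adds point.z to an existing struct column named "point"
+   *   TableChange change = TableChange.addColumn("point", "z", DataTypes.DoubleType);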

+   * The given name is used as-is to name the new column; names containing "." are not
+   * interpreted as field paths and are not handled differently.
+   *
+   * @param parent the new field's parent
+   * @param name the new field name
+   * @param dataType the new field's data type
+   * @return a TableChange for the addition
+   */
+  static TableChange addColumn(String parent, String name, DataType dataType) {
+    return new AddColumn(parent, name, dataType);
+  }
+
+  /**
+   * Create a TableChange for renaming a field.
+   *

+   * The name is used to find the field to rename. The new name replaces the name of the field
+   * itself; the rest of the path is unchanged. For example, renameColumn("a.b.c", "x") should
+   * produce the column a.b.x.
+   *
+   * @param name the current field name
+   * @param newName the new name
+   * @return a TableChange for the rename
+   */
+  static TableChange renameColumn(String name, String newName) {
+    return new RenameColumn(name, newName);
+  }
+
+  /**
+   * Create a TableChange for updating the type of a field.
+   *

+ * The name is used to find the field to update. + * + * @param name the field name + * @param newDataType the new data type + * @return a TableChange for the update + */ + static TableChange updateColumn(String name, DataType newDataType) { + return new UpdateColumn(name, newDataType); + } + + /** + * Create a TableChange for deleting a field from a table. + * + * @param name the name of the field to delete + * @return a TableChange for the delete + */ + static TableChange deleteColumn(String name) { + return new DeleteColumn(name); + } + + final class AddColumn implements TableChange { + private final String parent; + private final String name; + private final DataType dataType; + + private AddColumn(String parent, String name, DataType dataType) { + this.parent = parent; + this.name = name; + this.dataType = dataType; + } + + public String parent() { + return parent; + } + + public String name() { + return name; + } + + public DataType type() { + return dataType; + } + } + + final class RenameColumn implements TableChange { + private final String name; + private final String newName; + + private RenameColumn(String name, String newName) { + this.name = name; + this.newName = newName; + } + + public String name() { + return name; + } + + public String newName() { + return newName; + } + } + + final class UpdateColumn implements TableChange { + private final String name; + private final DataType newDataType; + + private UpdateColumn(String name, DataType newDataType) { + this.name = name; + this.newDataType = newDataType; + } + + public String name() { + return name; + } + + public DataType newDataType() { + return newDataType; + } + } + + final class DeleteColumn implements TableChange { + private final String name; + + private DeleteColumn(String name) { + this.name = name; + } + + public String name() { + return name; + } + } + +} diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CaseInsensitiveStringMapSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CaseInsensitiveStringMapSuite.java new file mode 100644 index 000000000000..0d869108fa7d --- /dev/null +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CaseInsensitiveStringMapSuite.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalog.v2; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashSet; +import java.util.Set; + +public class CaseInsensitiveStringMapSuite { + @Test + public void testPutAndGet() { + CaseInsensitiveStringMap options = CaseInsensitiveStringMap.empty(); + options.put("kEy", "valUE"); + + Assert.assertEquals("Should return correct value for lower-case key", + "valUE", options.get("key")); + Assert.assertEquals("Should return correct value for upper-case key", + "valUE", options.get("KEY")); + } + + @Test + public void testKeySet() { + CaseInsensitiveStringMap options = CaseInsensitiveStringMap.empty(); + options.put("kEy", "valUE"); + + Set expectedKeySet = new HashSet<>(); + expectedKeySet.add("key"); + + Assert.assertEquals("Should return lower-case key set", expectedKeySet, options.keySet()); + } +} diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CatalogLoadingSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CatalogLoadingSuite.java new file mode 100644 index 000000000000..62e26af7f0c6 --- /dev/null +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CatalogLoadingSuite.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalog.v2; + +import org.apache.spark.SparkException; +import org.apache.spark.sql.internal.SQLConf; +import org.junit.Assert; +import org.junit.Test; + +import java.util.concurrent.Callable; + +public class CatalogLoadingSuite { + @Test + public void testLoad() throws SparkException { + SQLConf conf = new SQLConf(); + conf.setConfString("spark.sql.catalog.test-name", TestCatalogProvider.class.getCanonicalName()); + + CatalogProvider provider = Catalogs.load("test-name", conf); + Assert.assertNotNull("Should instantiate a non-null provider", provider); + Assert.assertEquals("Provider should have correct implementation", + TestCatalogProvider.class, provider.getClass()); + + TestCatalogProvider testProvider = (TestCatalogProvider) provider; + Assert.assertEquals("Options should contain only one key", 1, testProvider.options.size()); + Assert.assertEquals("Options should contain correct catalog name", + "test-name", testProvider.options.get("name")); + } + + @Test + public void testInitializationOptions() throws SparkException { + SQLConf conf = new SQLConf(); + conf.setConfString("spark.sql.catalog.test-name", TestCatalogProvider.class.getCanonicalName()); + conf.setConfString("spark.sql.catalog.test-name.name", "overwritten"); + conf.setConfString("spark.sql.catalog.test-name.kEy", "valUE"); + + CatalogProvider provider = Catalogs.load("test-name", conf); + Assert.assertNotNull("Should instantiate a non-null provider", provider); + Assert.assertEquals("Provider should have correct implementation", + TestCatalogProvider.class, provider.getClass()); + + TestCatalogProvider testProvider = (TestCatalogProvider) provider; + + Assert.assertEquals("Options should contain only two keys", 2, testProvider.options.size()); + Assert.assertEquals("Options should contain correct catalog name", + "test-name", testProvider.options.get("name")); + Assert.assertEquals("Options should contain correct value for key", + "valUE", testProvider.options.get("key")); + } + + @Test + public void testLoadWithoutConfig() { + SQLConf conf = new SQLConf(); + + SparkException exc = intercept(SparkException.class, () -> Catalogs.load("missing", conf)); + + Assert.assertTrue("Should complain that implementation is not configured", + exc.getMessage().contains("provider not found: spark.sql.catalog.missing is not defined")); + Assert.assertTrue("Should identify the catalog by name", exc.getMessage().contains("missing")); + } + + @Test + public void testLoadMissingClass() { + SQLConf conf = new SQLConf(); + conf.setConfString("spark.sql.catalog.missing", "com.example.NoSuchCatalogProvider"); + + SparkException exc = intercept(SparkException.class, () -> Catalogs.load("missing", conf)); + + Assert.assertTrue("Should complain that the class is not found", + exc.getMessage().contains("Cannot find catalog provider class")); + Assert.assertTrue("Should identify the catalog by name", exc.getMessage().contains("missing")); + Assert.assertTrue("Should identify the missing class", + exc.getMessage().contains("com.example.NoSuchCatalogProvider")); + } + + @Test + public void testLoadNonCatalogProvider() { + SQLConf conf = new SQLConf(); + String invalidClassName = InvalidCatalogProvider.class.getCanonicalName(); + conf.setConfString("spark.sql.catalog.invalid", invalidClassName); + + SparkException exc = intercept(SparkException.class, () -> Catalogs.load("invalid", conf)); + + Assert.assertTrue("Should complain that class does not implement CatalogProvider", + exc.getMessage().contains("does not 
implement CatalogProvider")); + Assert.assertTrue("Should identify the catalog by name", exc.getMessage().contains("invalid")); + Assert.assertTrue("Should identify the class", exc.getMessage().contains(invalidClassName)); + } + + @Test + public void testLoadConstructorFailureCatalogProvider() { + SQLConf conf = new SQLConf(); + String invalidClassName = ConstructorFailureCatalogProvider.class.getCanonicalName(); + conf.setConfString("spark.sql.catalog.invalid", invalidClassName); + + RuntimeException exc = intercept(RuntimeException.class, () -> Catalogs.load("invalid", conf)); + + Assert.assertTrue("Should have expected error message", + exc.getMessage().contains("Expected failure")); + } + + @Test + public void testLoadAccessErrorCatalogProvider() { + SQLConf conf = new SQLConf(); + String invalidClassName = AccessErrorCatalogProvider.class.getCanonicalName(); + conf.setConfString("spark.sql.catalog.invalid", invalidClassName); + + SparkException exc = intercept(SparkException.class, () -> Catalogs.load("invalid", conf)); + + Assert.assertTrue("Should complain that no public constructor is provided", + exc.getMessage().contains("Failed to call public no-arg constructor for catalog")); + Assert.assertTrue("Should identify the catalog by name", exc.getMessage().contains("invalid")); + Assert.assertTrue("Should identify the class", exc.getMessage().contains(invalidClassName)); + } + + @SuppressWarnings("unchecked") + public static E intercept(Class expected, Callable callable) { + try { + callable.call(); + Assert.fail("No exception was thrown, expected: " + + expected.getName()); + } catch (Exception actual) { + try { + Assert.assertEquals(expected, actual.getClass()); + return (E) actual; + } catch (AssertionError e) { + e.addSuppressed(actual); + throw e; + } + } + // Compiler doesn't catch that Assert.fail will always throw an exception. 
+ throw new UnsupportedOperationException("[BUG] Should not reach this statement"); + } +} + +class TestCatalogProvider implements CatalogProvider { + CaseInsensitiveStringMap options = null; + + TestCatalogProvider() { + } + + @Override + public void initialize(CaseInsensitiveStringMap options) { + this.options = options; + } +} + +class ConstructorFailureCatalogProvider implements CatalogProvider { // fails in its constructor + ConstructorFailureCatalogProvider() { + throw new RuntimeException("Expected failure."); + } + + @Override + public void initialize(CaseInsensitiveStringMap options) { + } +} + +class AccessErrorCatalogProvider implements CatalogProvider { // no public constructor + private AccessErrorCatalogProvider() { + } + + @Override + public void initialize(CaseInsensitiveStringMap options) { + } +} + +class InvalidCatalogProvider { // doesn't implement CatalogProvider + public void initialize(CaseInsensitiveStringMap options) { + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index d9278d8cd23d..a4c8de6afceb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -21,6 +21,7 @@ import java.io.Closeable import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -31,6 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.catalog.Catalog +import org.apache.spark.sql.catalog.v2.{CatalogProvider, Catalogs} import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.encoders._ @@ -610,6 +612,12 @@ class SparkSession private( */ @transient lazy val catalog: Catalog = new CatalogImpl(self) + @transient private lazy val catalogs = new mutable.HashMap[String, CatalogProvider]() + + private[sql] def catalog(name: String): CatalogProvider = synchronized { + catalogs.getOrElseUpdate(name, Catalogs.load(name, sessionState.conf)) + } + /** * Returns the specified table/view as a `DataFrame`. * From 622180a50e05b4d968380824f5dbbe5f89e42422 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 15 Aug 2018 14:03:13 -0700 Subject: [PATCH 2/3] SPARK-24252: Add PartitionTransform to replace Expression. Expression is internal and should not be used in public APIs. To avoid using Expression in the TableCatalog API, this commit adds a small set of transformations that are used to communicate partitioning to catalog implementations. This also adds an apply transformation that passes the name of a transform instead of a Transform class. This can be used to pass transforms that are unknown to Spark to the underlying catalog implementation. 
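A rough sketch of the intended use, not part of this patch. It assumes java.util.Arrays and an
in-scope TableCatalog (catalog), TableIdentifier (ident), StructType (schema), and properties
map; the column names and bucket count are illustrative:

    List<PartitionTransform> partitioning = Arrays.asList(
        Transforms.date("ts"),           // derive a date partition value from a timestamp column
        Transforms.bucketBy("id", 16));  // hash "id" into 16 buckets

    // the catalog only receives transform names and column references; Spark does not
    // evaluate the transforms itself
    catalog.createTable(ident, schema, partitioning, properties);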
--- .../sql/catalog/v2/PartitionTransform.java | 49 ++++ .../apache/spark/sql/catalog/v2/Table.java | 7 +- .../spark/sql/catalog/v2/TableCatalog.java | 4 +- .../spark/sql/catalog/v2/Transforms.java | 223 ++++++++++++++++++ 4 files changed, 277 insertions(+), 6 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/PartitionTransform.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Transforms.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/PartitionTransform.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/PartitionTransform.java new file mode 100644 index 000000000000..117c99a42eb8 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/PartitionTransform.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2; + +/** + * A logical transformation function. + *

+ * This does not support applying transformations; it only communicates the type of transformation + * and its input column references. + *

+ * This interface is used to pass partitioning transformations to v2 catalog implementations. For + * example a table may partition data by the date of a timestamp column, ts, using + * date(ts). This is similar to org.apache.spark.sql.sources.Filter, which is used to + * pass boolean filter expressions to data source implementations. + *
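+ * For example, a table partitioned by date(ts) is described by a transform whose name is
+ * "date" and whose single column reference is "ts". A sketch, using the factory methods in
+ * the Transforms class added with this API:
+ *
+ *   PartitionTransform t = Transforms.date("ts");
+ *   t.name();        // "date"
+ *   t.references();  // new String[] { "ts" }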

+ * To use data values directly as partition values, use the "identity" transform:
+ * identity(col). Identity partition transforms are the only transforms used by Hive.
+ * In Hive tables, partition column values are taken from the data as-is and are used without
+ * modification to partition the remaining data columns.
+ *

+ * Table formats other than Hive can use partition transforms to automatically derive partition + * values from rows and to transform data predicates to partition predicates. + */ +public interface PartitionTransform { + /** + * The name of this transform. + */ + String name(); + + /** + * The data columns that are referenced by this transform. + */ + String[] references(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Table.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Table.java index 30a20f27b8c6..644f7d474fc5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Table.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Table.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalog.v2; -import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.types.StructType; import java.util.List; @@ -40,8 +39,8 @@ public interface Table { StructType schema(); /** - * Return the table partitioning expressions. - * @return this table's partitioning expressions + * Return the table partitioning transforms. + * @return this table's partitioning transforms */ - List partitionExpressions(); + List partitioning(); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableCatalog.java index 539beb0c39c5..8b9c89b509dc 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableCatalog.java @@ -20,7 +20,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier; import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; -import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.types.StructType; import java.util.Arrays; @@ -89,10 +88,11 @@ default Table createTable(TableIdentifier ident, * @param properties a string map of table properties * @return metadata for the new table * @throws TableAlreadyExistsException If a table already exists for the identifier + * @throws UnsupportedOperationException If a requested partition transform is not supported */ Table createTable(TableIdentifier ident, StructType schema, - List partitions, + List partitions, Map properties) throws TableAlreadyExistsException; /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Transforms.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Transforms.java new file mode 100644 index 000000000000..561c1c539712 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Transforms.java @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2; + +/** + * A standard set of transformations that are passed to data sources during table creation. + * + * @see PartitionTransform + */ +public class Transforms { + private Transforms() { + } + + /** + * Create a transform for a column with the given name. + *

+ * This transform is used to pass named transforms that are not known to Spark. + * + * @param transform a name of the transform to apply to the column + * @param colName a column name + * @return an Apply transform for the column + */ + public static PartitionTransform apply(String transform, String colName) { + if ("identity".equals(transform)) { + return identity(colName); + } else if ("year".equals(transform)) { + return year(colName); + } else if ("month".equals(transform)) { + return month(colName); + } else if ("date".equals(transform)) { + return date(colName); + } else if ("hour".equals(transform)) { + return hour(colName); + } + + // unknown transform names are passed to sources with Apply + return new Apply(transform, colName); + } + + /** + * Create an identity transform for a column name. + * + * @param colName a column name + * @return an Identity transform for the column + */ + public static Identity identity(String colName) { + return new Identity(colName); + } + + /** + * Create a bucket transform for a column name with the given number of buckets. + * + * @param colName a column name + * @param numBuckets the number of buckets + * @return a BucketBy transform for the column + */ + public static Bucket bucketBy(String colName, int numBuckets) { + return new Bucket(colName, numBuckets); + } + + /** + * Create a year transform for a column name. + *

+ * The corresponding column should be a timestamp or date column. + * + * @param colName a column name + * @return a Year transform for the column + */ + public static Year year(String colName) { + return new Year(colName); + } + + /** + * Create a month transform for a column name. + *

+ * The corresponding column should be a timestamp or date column. + * + * @param colName a column name + * @return a Month transform for the column + */ + public static Month month(String colName) { + return new Month(colName); + } + + /** + * Create a date transform for a column name. + *

+ * The corresponding column should be a timestamp or date column. + * + * @param colName a column name + * @return a Date transform for the column + */ + public static Date date(String colName) { + return new Date(colName); + } + + /** + * Create a date and hour transform for a column name. + *

+ * The corresponding column should be a timestamp column. + * + * @param colName a column name + * @return a DateAndHour transform for the column + */ + public static DateAndHour hour(String colName) { + return new DateAndHour(colName); + } + + private abstract static class SingleColumnTransform implements PartitionTransform { + private final String[] colNames; + + private SingleColumnTransform(String colName) { + this.colNames = new String[] { colName }; + } + + @Override + public String[] references() { + return colNames; + } + } + + public static final class Identity extends SingleColumnTransform { + private Identity(String colName) { + super(colName); + } + + @Override + public String name() { + return "identity"; + } + } + + public static final class Bucket extends SingleColumnTransform { + private final int numBuckets; + + private Bucket(String colName, int numBuckets) { + super(colName); + this.numBuckets = numBuckets; + } + + @Override + public String name() { + return "bucket"; + } + + public int numBuckets() { + return numBuckets; + } + } + + public static final class Year extends SingleColumnTransform { + private Year(String colName) { + super(colName); + } + + @Override + public String name() { + return "year"; + } + } + + public static final class Month extends SingleColumnTransform { + private Month(String colName) { + super(colName); + } + + @Override + public String name() { + return "month"; + } + } + + public static final class Date extends SingleColumnTransform { + private Date(String colName) { + super(colName); + } + + @Override + public String name() { + return "date"; + } + } + + public static final class DateAndHour extends SingleColumnTransform { + private DateAndHour(String colName) { + super(colName); + } + + @Override + public String name() { + return "hour"; + } + } + + public static final class Apply extends SingleColumnTransform { + private final String transformName; + + private Apply(String transformName, String colName) { + super(colName); + this.transformName = transformName; + } + + @Override + public String name() { + return transformName; + } + } +} From e50d94bbc205369969e0c1707bda057ee99f0007 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 25 Jul 2018 11:11:45 -0700 Subject: [PATCH 3/3] SPARK-24923: Add v2 CTAS and RTAS support. This uses the catalog API introduced in SPARK-24252 to implement CTAS and RTAS plans. 
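A rough sketch of how the new path can be reached from the public API, not part of this patch.
The Dataset df, the catalog name "testcat", and the table name are illustrative, and this
assumes DataSourceOptions.TABLE_KEY is the "table" option:

    // with an existing table, SaveMode.Overwrite plans ReplaceTableAsSelect and Append plans
    // AppendData; when the table does not exist, Overwrite, ErrorIfExists, and Ignore plan
    // CreateTableAsSelect
    df.write()
        .option("catalog", "testcat")
        .option("table", "target_table")
        .mode(SaveMode.Overwrite)
        .save();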
--- .../sql/catalyst/analysis/NamedRelation.scala | 3 + .../plans/logical/basicLogicalOperators.scala | 34 +++- .../spark/sql/sources/v2/ReadSupport.java | 2 +- .../spark/sql/sources/v2/WriteSupport.java | 2 +- .../apache/spark/sql/DataFrameReader.scala | 17 +- .../apache/spark/sql/DataFrameWriter.scala | 47 +++++- .../datasources/v2/DataSourceV2Relation.scala | 115 ++++++------- .../datasources/v2/DataSourceV2ScanExec.scala | 5 +- .../datasources/v2/DataSourceV2Strategy.scala | 36 ++++- .../v2/DataSourceV2StringFormat.scala | 11 +- .../v2/WriteToDataSourceV2Exec.scala | 92 ++++++++++- .../sources/v2/DataSourceV2Implicits.scala | 151 ++++++++++++++++++ 12 files changed, 420 insertions(+), 95 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/sources/v2/DataSourceV2Implicits.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala index ad201f947b67..2e72d13cb923 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan trait NamedRelation extends LogicalPlan { def name: String + + def output: Seq[AttributeReference] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 7ff83a9be362..510bb62b55c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.{AliasIdentifier} +import org.apache.spark.sql.catalog.v2.{PartitionTransform, TableCatalog} +import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} import org.apache.spark.sql.catalyst.expressions._ @@ -384,6 +385,37 @@ object AppendData { } } +/** + * Create a new table from a select query. + */ +case class CreateTableAsSelect( + catalog: TableCatalog, + table: TableIdentifier, + partitioning: Seq[PartitionTransform], + query: LogicalPlan, + writeOptions: Map[String, String], + ignoreIfExists: Boolean) extends LogicalPlan { + + override def children: Seq[LogicalPlan] = Seq(query) + override def output: Seq[Attribute] = Seq.empty + override lazy val resolved = true +} + +/** + * Replace a table with the results of a select query. + */ +case class ReplaceTableAsSelect( + catalog: TableCatalog, + table: TableIdentifier, + partitioning: Seq[PartitionTransform], + query: LogicalPlan, + writeOptions: Map[String, String]) extends LogicalPlan { + + override def children: Seq[LogicalPlan] = Seq(query) + override def output: Seq[Attribute] = Seq.empty + override lazy val resolved = true +} + /** * Insert some data into a table. Note that this plan is unresolved and has to be replaced by the * concrete implementations during analysis. 
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java index 80ac08ee5ff5..c99d803822f2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java @@ -27,7 +27,7 @@ * provide data reading ability and scan the data from the data source. */ @InterfaceStability.Evolving -public interface ReadSupport extends DataSourceV2 { +public interface ReadSupport { /** * Creates a {@link DataSourceReader} to scan the data from this data source. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java index 048787a7a0a0..16484d4b8439 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -29,7 +29,7 @@ * provide data writing ability and save the data to the data source. */ @InterfaceStability.Evolving -public interface WriteSupport extends DataSourceV2 { +public interface WriteSupport { /** * Creates an optional {@link DataSourceWriter} to save the data to this data source. Data diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 9bd113419ae4..9bd5f348db3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, DataSourceV2Implicits, ReadSupport} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -191,6 +191,21 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { "read files of Hive data source directly.") } + import DataSourceV2Implicits._ + + extraOptions.get("catalog") match { + case Some(catalogName) if extraOptions.get(DataSourceOptions.TABLE_KEY).isDefined => + val catalog = sparkSession.catalog(catalogName).asTableCatalog + val options = extraOptions.toMap + val identifier = options.table.get + + return Dataset.ofRows(sparkSession, + DataSourceV2Relation.create( + catalogName, identifier, catalog.loadTable(identifier), options)) + + case _ => + } + val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { val ds = cls.newInstance().asInstanceOf[DataSourceV2] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 650c91790a75..5bec85e599ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -25,7 +25,7 @@ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.TableIdentifier import 
org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, InsertIntoTable, LogicalPlan, ReplaceTableAsSelect} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} @@ -236,6 +236,51 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { assertNotBucketed("save") + import DataSourceV2Implicits._ + + extraOptions.get("catalog") match { + case Some(catalogName) if extraOptions.get(DataSourceOptions.TABLE_KEY).isDefined => + val catalog = df.sparkSession.catalog(catalogName).asTableCatalog + val options = extraOptions.toMap + val identifier = options.table.get + val exists = catalog.tableExists(identifier) + + (exists, mode) match { + case (true, SaveMode.ErrorIfExists) => + throw new AnalysisException(s"Table already exists: ${identifier.quotedString}") + + case (true, SaveMode.Overwrite) => + runCommand(df.sparkSession, "save") { + ReplaceTableAsSelect(catalog, identifier, Seq.empty, df.logicalPlan, options) + } + + case (true, SaveMode.Append) => + val relation = DataSourceV2Relation.create( + catalogName, identifier, catalog.loadTable(identifier), options) + + runCommand(df.sparkSession, "save") { + AppendData.byName(relation, df.logicalPlan) + } + + case (false, SaveMode.Append) => + throw new AnalysisException(s"Table does not exist: ${identifier.quotedString}") + + case (false, SaveMode.ErrorIfExists) | + (false, SaveMode.Ignore) | + (false, SaveMode.Overwrite) => + + runCommand(df.sparkSession, "save") { + CreateTableAsSelect(catalog, identifier, Seq.empty, df.logicalPlan, options, + ignoreIfExists = mode == SaveMode.Ignore) + } + + case _ => + return // table exists and mode is ignore + } + + case _ => + } + val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) if (classOf[DataSourceV2].isAssignableFrom(cls)) { val source = cls.newInstance().asInstanceOf[DataSourceV2] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index a4bfc861cc9a..ebe25eb8276c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -17,20 +17,18 @@ package org.apache.spark.sql.execution.datasources.v2 -import java.util.UUID - -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.catalog.v2.Table import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, WriteSupport} +import org.apache.spark.sql.sources.v2.DataSourceV2Implicits._ import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index a4bfc861cc9a..ebe25eb8276c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -17,20 +17,18 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import java.util.UUID
-
-import scala.collection.JavaConverters._
-
-import org.apache.spark.sql.{AnalysisException, SaveMode}
+import org.apache.spark.sql.catalog.v2.Table
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.sources.DataSourceRegister
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, WriteSupport}
+import org.apache.spark.sql.sources.v2.DataSourceV2Implicits._
 import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsReportStatistics}
 import org.apache.spark.sql.sources.v2.writer.DataSourceWriter
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
 
 /**
  * A logical plan representing a data source v2 scan.
@@ -48,10 +46,10 @@ case class DataSourceV2Relation(
     userSpecifiedSchema: Option[StructType] = None)
   extends LeafNode with MultiInstanceRelation with NamedRelation with DataSourceV2StringFormat {
 
-  import DataSourceV2Relation._
+  override def sourceName: String = source.name
 
   override def name: String = {
-    tableIdent.map(_.unquotedString).getOrElse(s"${source.name}:unknown")
+    tableIdent.map(_.unquotedString).getOrElse(s"$sourceName:unknown")
   }
 
   override def pushedFilters: Seq[Expression] = Seq.empty
@@ -62,7 +60,7 @@ case class DataSourceV2Relation(
 
   def newWriter(): DataSourceWriter = source.createWriter(options, schema)
 
-  override def computeStats(): Statistics = newReader match {
+  override def computeStats(): Statistics = newReader() match {
     case r: SupportsReportStatistics =>
       Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
     case _ =>
@@ -74,6 +72,43 @@ case class DataSourceV2Relation(
   }
 }
 
+/**
+ * A logical plan representing a data source v2 table.
+ *
+ * @param ident The table's TableIdentifier.
+ * @param table The table.
+ * @param output The output attributes of the table.
+ * @param options The options for this scan or write.
+ */
+case class TableV2Relation(
+    catalogName: String,
+    ident: TableIdentifier,
+    table: Table,
+    output: Seq[AttributeReference],
+    options: Map[String, String])
+  extends LeafNode with MultiInstanceRelation with NamedRelation {
+
+  import org.apache.spark.sql.sources.v2.DataSourceV2Implicits._
+
+  override def name: String = ident.unquotedString
+
+  override def simpleString: String =
+    s"RelationV2 $name ${Utils.truncatedString(output, "[", ", ", "]")}"
+
+  def newReader(): DataSourceReader = table.createReader(options, None)
+
+  override def computeStats(): Statistics = newReader() match {
+    case r: SupportsReportStatistics =>
+      Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
+    case _ =>
+      Statistics(sizeInBytes = conf.defaultSizeInBytes)
+  }
+
+  override def newInstance(): TableV2Relation = {
+    copy(output = output.map(_.newInstance()))
+  }
+}
+
 /**
  * A specialization of [[DataSourceV2Relation]] with the streaming bit set to true.
  *
@@ -88,6 +123,8 @@ case class StreamingDataSourceV2Relation(
     reader: DataSourceReader)
   extends LeafNode with MultiInstanceRelation with DataSourceV2StringFormat {
 
+  override def sourceName: String = source.name
+
   override def isStreaming: Boolean = true
 
   override def simpleString: String = "Streaming RelationV2 " + metadataString
@@ -116,68 +153,22 @@ case class StreamingDataSourceV2Relation(
 }
 
 object DataSourceV2Relation {
-  private implicit class SourceHelpers(source: DataSourceV2) {
-    def asReadSupport: ReadSupport = {
-      source match {
-        case support: ReadSupport =>
-          support
-        case _ =>
-          throw new AnalysisException(s"Data source is not readable: $name")
-      }
-    }
-
-    def asWriteSupport: WriteSupport = {
-      source match {
-        case support: WriteSupport =>
-          support
-        case _ =>
-          throw new AnalysisException(s"Data source is not writable: $name")
-      }
-    }
-
-    def name: String = {
-      source match {
-        case registered: DataSourceRegister =>
-          registered.shortName()
-        case _ =>
-          source.getClass.getSimpleName
-      }
-    }
-
-    def createReader(
-        options: Map[String, String],
-        userSpecifiedSchema: Option[StructType]): DataSourceReader = {
-      val v2Options = new DataSourceOptions(options.asJava)
-      userSpecifiedSchema match {
-        case Some(s) =>
-          asReadSupport.createReader(s, v2Options)
-        case _ =>
-          asReadSupport.createReader(v2Options)
-      }
-    }
-
-    def createWriter(
-        options: Map[String, String],
-        schema: StructType): DataSourceWriter = {
-      val v2Options = new DataSourceOptions(options.asJava)
-      asWriteSupport.createWriter(UUID.randomUUID.toString, schema, SaveMode.Append, v2Options).get
-    }
-  }
-
   def create(
       source: DataSourceV2,
       options: Map[String, String],
       tableIdent: Option[TableIdentifier] = None,
-      userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = {
+      userSpecifiedSchema: Option[StructType] = None): NamedRelation = {
     val reader = source.createReader(options, userSpecifiedSchema)
-    val ident = tableIdent.orElse(tableFromOptions(options))
+    val ident = tableIdent.orElse(options.table)
     DataSourceV2Relation(
       source, reader.readSchema().toAttributes, options, ident, userSpecifiedSchema)
   }
 
-  private def tableFromOptions(options: Map[String, String]): Option[TableIdentifier] = {
-    options
-      .get(DataSourceOptions.TABLE_KEY)
-      .map(TableIdentifier(_, options.get(DataSourceOptions.DATABASE_KEY)))
+  def create(
+      catalogName: String,
+      ident: TableIdentifier,
+      table: Table,
+      options: Map[String, String]): NamedRelation = {
+    TableV2Relation(catalogName, ident, table, table.schema.toAttributes, options)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
index c8494f97f176..f1071147ddf6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.plans.physical
 import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
 import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.streaming.continuous._
-import org.apache.spark.sql.sources.v2.DataSourceV2
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -36,7 +35,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
  */
 case class DataSourceV2ScanExec(
     output: Seq[AttributeReference],
-    @transient source: DataSourceV2,
+    @transient sourceName: String,
     @transient options: Map[String, String],
     @transient pushedFilters: Seq[Expression],
     @transient reader: DataSourceReader)
@@ -52,7 +51,7 @@ case class DataSourceV2ScanExec(
   }
 
   override def hashCode(): Int = {
-    Seq(output, source, options).hashCode()
+    Seq(output, sourceName, options).hashCode()
   }
 
   override def outputPartitioning: physical.Partitioning = reader match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 6daaa4c65c33..b10b3eb83e32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -20,9 +20,11 @@ package org.apache.spark.sql.execution.datasources.v2
 import scala.collection.mutable
 
 import org.apache.spark.sql.{sources, Strategy}
+import org.apache.spark.sql.catalog.v2.TableCatalog
+import org.apache.spark.sql.catalyst.analysis.NamedRelation
 import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression}
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Repartition}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, Repartition, ReplaceTableAsSelect}
 import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
@@ -31,6 +33,8 @@ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
 
 object DataSourceV2Strategy extends Strategy {
 
+  import org.apache.spark.sql.sources.v2.DataSourceV2Implicits._
+
   /**
    * Pushes down filters to the data source reader
    *
@@ -81,7 +85,7 @@ object DataSourceV2Strategy extends Strategy {
   // TODO: nested column pruning.
   private def pruneColumns(
       reader: DataSourceReader,
-      relation: DataSourceV2Relation,
+      relation: NamedRelation,
       exprs: Seq[Expression]): Seq[AttributeReference] = {
     reader match {
       case r: SupportsPushDownRequiredColumns =>
@@ -102,10 +106,15 @@
     }
   }
 
   override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-    case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
-      val reader = relation.newReader()
+    case PhysicalOperation(project, filters, relation: NamedRelation)
+        if relation.isInstanceOf[DataSourceV2Relation] || relation.isInstanceOf[TableV2Relation] =>
+
+      val (reader, options, sourceName) = relation match {
+        case r: DataSourceV2Relation => (r.newReader(), r.options, r.sourceName)
+        case r: TableV2Relation => (r.newReader(), r.options, r.catalogName)
+      }
+
       // `pushedFilters` will be pushed down and evaluated in the underlying data sources.
       // `postScanFilters` need to be evaluated after the scan.
       // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter.
@@ -113,14 +122,14 @@ object DataSourceV2Strategy extends Strategy {
       val output = pruneColumns(reader, relation, project ++ postScanFilters)
       logInfo(
         s"""
-           |Pushing operators to ${relation.source.getClass}
+           |Pushing operators to $sourceName
            |Pushed Filters: ${pushedFilters.mkString(", ")}
            |Post-Scan Filters: ${postScanFilters.mkString(",")}
            |Output: ${output.mkString(", ")}
          """.stripMargin)
 
       val scan = DataSourceV2ScanExec(
-        output, relation.source, relation.options, pushedFilters, reader)
+        output, sourceName, options, pushedFilters, reader)
 
       val filterCondition = postScanFilters.reduceLeftOption(And)
       val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan)
@@ -131,7 +140,7 @@ object DataSourceV2Strategy extends Strategy {
     case r: StreamingDataSourceV2Relation =>
       // ensure there is a projection, which will produce unsafe rows required by some operators
      ProjectExec(r.output,
-        DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader)) :: Nil
+        DataSourceV2ScanExec(r.output, r.source.name, r.options, r.pushedFilters, r.reader)) :: Nil
 
     case WriteToDataSourceV2(writer, query) =>
       WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil
@@ -139,6 +148,17 @@ object DataSourceV2Strategy extends Strategy {
     case AppendData(r: DataSourceV2Relation, query, _) =>
       WriteToDataSourceV2Exec(r.newWriter(), planLater(query)) :: Nil
 
+    case AppendData(r: TableV2Relation, query, _) =>
+      AppendDataExec(r.table, r.options, planLater(query)) :: Nil
+
+    case CreateTableAsSelect(catalog, ident, partitioning, query, writeOptions, ignoreIfExists) =>
+      CreateTableAsSelectExec(catalog, ident, partitioning, Map.empty, writeOptions,
+        planLater(query), ignoreIfExists) :: Nil
+
+    case ReplaceTableAsSelect(catalog, ident, partitioning, query, writeOptions) =>
+      ReplaceTableAsSelectExec(catalog, ident, partitioning, Map.empty, writeOptions,
+        planLater(query)) :: Nil
+
     case WriteToContinuousDataSource(writer, query) =>
       WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala
index 97e6c6d702ac..5c5a51c99c77 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala
@@ -20,8 +20,6 @@ package org.apache.spark.sql.execution.datasources.v2
 import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.sources.DataSourceRegister
-import org.apache.spark.sql.sources.v2.DataSourceV2
 import org.apache.spark.util.Utils
 
 /**
@@ -34,7 +32,7 @@ trait DataSourceV2StringFormat {
-   * The instance of this data source implementation. Note that we only consider its class in
-   * equals/hashCode, not the instance itself.
+   * The name of the data source this plan refers to. The name, not a source instance,
+   * is what participates in equals/hashCode and in the string output.
    */
-  def source: DataSourceV2
+  def sourceName: String
 
   /**
    * The output of the data source reader, w.r.t. column pruning.
@@ -51,13 +49,6 @@ trait DataSourceV2StringFormat {
    */
   def pushedFilters: Seq[Expression]
 
-  private def sourceName: String = source match {
-    case registered: DataSourceRegister => registered.shortName()
-    // source.getClass.getSimpleName can cause Malformed class name error,
-    // call safer `Utils.getSimpleName` instead
-    case _ => Utils.getSimpleName(source.getClass)
-  }
-
   def metadataString: String = {
     val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 59ebb9bc5431..c428ebe6d469 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -17,21 +17,21 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
+import scala.collection.JavaConverters._
 import scala.util.control.NonFatal
 
 import org.apache.spark.{SparkEnv, SparkException, TaskContext}
 import org.apache.spark.executor.CommitDeniedException
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalog.v2.{PartitionTransform, Table, TableCatalog}
+import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
+import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException}
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming.MicroBatchExecution
 import org.apache.spark.sql.sources.v2.writer._
-import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 
 /**
@@ -44,14 +44,92 @@ case class WriteToDataSourceV2(writer: DataSourceWriter, query: LogicalPlan) ext
   override def output: Seq[Attribute] = Nil
 }
 
+case class AppendDataExec(
+    table: Table,
+    writeOptions: Map[String, String],
+    plan: SparkPlan) extends V2TableWriteExec(writeOptions, plan) {
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    appendToTable(table)
+  }
+}
+
+case class CreateTableAsSelectExec(
+    catalog: TableCatalog,
+    ident: TableIdentifier,
+    partitioning: Seq[PartitionTransform],
+    properties: Map[String, String],
+    writeOptions: Map[String, String],
+    plan: SparkPlan,
+    ifNotExists: Boolean) extends V2TableWriteExec(writeOptions, plan) {
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    if (catalog.tableExists(ident)) {
+      if (ifNotExists) {
+        return sparkContext.parallelize(Seq.empty, 1)
+      }
+
+      throw new TableAlreadyExistsException(ident.database.getOrElse("null"), ident.table)
+    }
+
+    Utils.tryWithSafeFinallyAndFailureCallbacks({
+      val table = catalog.createTable(ident, plan.schema, partitioning.asJava, properties.asJava)
+      appendToTable(table)
+    })(catchBlock = {
+      catalog.dropTable(ident)
+    })
+  }
+}
+
+case class ReplaceTableAsSelectExec(
+    catalog: TableCatalog,
+    ident: TableIdentifier,
+    partitioning: Seq[PartitionTransform],
+    properties: Map[String, String],
+    writeOptions: Map[String, String],
+    plan: SparkPlan) extends V2TableWriteExec(writeOptions, plan) {
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    if (!catalog.tableExists(ident)) {
+      throw new NoSuchTableException(ident.database.getOrElse("null"), ident.table)
+    }
+
+    catalog.dropTable(ident)
+
+    Utils.tryWithSafeFinallyAndFailureCallbacks({
+      val table = catalog.createTable(ident, plan.schema, partitioning.asJava, properties.asJava)
+      appendToTable(table)
+    })(catchBlock = {
+      catalog.dropTable(ident)
+    })
+  }
+}
+
+case class WriteToDataSourceV2Exec(
+    writer: DataSourceWriter,
+    plan: SparkPlan) extends V2TableWriteExec(Map.empty, plan) {
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    doAppend(writer)
+  }
+}
+
 /**
- * The physical plan for writing data into data source v2.
+ * The base physical plan for writing data into data source v2.
  */
-case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) extends SparkPlan {
+abstract class V2TableWriteExec(
+    options: Map[String, String],
+    query: SparkPlan) extends SparkPlan {
+
+  import org.apache.spark.sql.sources.v2.DataSourceV2Implicits._
+
   override def children: Seq[SparkPlan] = Seq(query)
   override def output: Seq[Attribute] = Nil
 
-  override protected def doExecute(): RDD[InternalRow] = {
+  protected def appendToTable(table: Table): RDD[InternalRow] = {
+    doAppend(table.createWriter(options, query.schema))
+  }
+
+  protected def doAppend(writer: DataSourceWriter): RDD[InternalRow] = {
     val writeTask = writer.createWriterFactory()
     val useCommitCoordinator = writer.useCommitCoordinator
     val rdd = query.execute()
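The create-table-as-select plan above only drops the table it created when the write itself fails, which is why the cleanup goes in the failure callback rather than a finally block. In plain Scala (not Spark code), the contract it relies on is roughly the sketch below; the helper names are made up for illustration.

    // Plain-Scala sketch of the CTAS cleanup contract: create, write, and roll the
    // new table back only if the write throws.
    def ctas(createTable: () => Unit, write: () => Unit, dropTable: () => Unit): Unit = {
      createTable()
      try {
        write()
      } catch {
        case t: Throwable =>
          dropTable()   // undo the table we just created
          throw t
      }
    }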
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/v2/DataSourceV2Implicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/v2/DataSourceV2Implicits.scala
new file mode 100644
index 000000000000..816f6eb2fc97
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/v2/DataSourceV2Implicits.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2
+
+import java.util.UUID
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.{AnalysisException, SaveMode}
+import org.apache.spark.sql.catalog.v2.{CaseInsensitiveStringMap, CatalogProvider, Table, TableCatalog}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.sources.v2.reader.DataSourceReader
+import org.apache.spark.sql.sources.v2.writer.DataSourceWriter
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Implicit helper classes to make working with the v2 API in Scala easier.
+ */
+private[sql] object DataSourceV2Implicits {
+  implicit class CatalogHelper(catalog: CatalogProvider) {
+    def asTableCatalog: TableCatalog = catalog match {
+      case tableCatalog: TableCatalog =>
+        tableCatalog
+      case _ =>
+        throw new UnsupportedOperationException(s"Catalog $catalog does not support tables")
+    }
+  }
+
+  implicit class TableHelper(table: Table) {
+    def asReadSupport: ReadSupport = {
+      table match {
+        case support: ReadSupport =>
+          support
+        case _ =>
+          throw new AnalysisException(s"Table is not readable: $table")
+      }
+    }
+
+    def asWriteSupport: WriteSupport = {
+      table match {
+        case support: WriteSupport =>
+          support
+        case _ =>
+          throw new AnalysisException(s"Table is not writable: $table")
+      }
+    }
+
+    def createReader(
+        options: Map[String, String],
+        userSpecifiedSchema: Option[StructType]): DataSourceReader = {
+      userSpecifiedSchema match {
+        case Some(schema) =>
+          asReadSupport.createReader(schema, options.asDataSourceOptions)
+        case None =>
+          asReadSupport.createReader(options.asDataSourceOptions)
+      }
+    }
+
+    def createWriter(
+        options: Map[String, String],
+        schema: StructType): DataSourceWriter = {
+      asWriteSupport.createWriter(
+        UUID.randomUUID.toString, schema, SaveMode.Append, options.asDataSourceOptions).get
+    }
+  }
+
+  implicit class SourceHelper(source: DataSourceV2) {
+    def asReadSupport: ReadSupport = {
+      source match {
+        case support: ReadSupport =>
+          support
+        case _ =>
+          throw new AnalysisException(s"Data source is not readable: $name")
+      }
+    }
+
+    def asWriteSupport: WriteSupport = {
+      source match {
+        case support: WriteSupport =>
+          support
+        case _ =>
+          throw new AnalysisException(s"Data source is not writable: $name")
+      }
+    }
+
+    def name: String = {
+      source match {
+        case registered: DataSourceRegister =>
+          registered.shortName()
+        case _ =>
+          source.getClass.getSimpleName
+      }
+    }
+
+    def createReader(
+        options: Map[String, String],
+        userSpecifiedSchema: Option[StructType]): DataSourceReader = {
+      userSpecifiedSchema match {
+        case Some(schema) =>
+          asReadSupport.createReader(schema, options.asDataSourceOptions)
+        case None =>
+          asReadSupport.createReader(options.asDataSourceOptions)
+      }
+    }
+
+    def createWriter(
+        options: Map[String, String],
+        schema: StructType): DataSourceWriter = {
+      val v2Options = new DataSourceOptions(options.asJava)
+      asWriteSupport.createWriter(UUID.randomUUID.toString, schema, SaveMode.Append, v2Options).get
+    }
+  }
+
+  implicit class OptionsHelper(options: Map[String, String]) {
+    def asDataSourceOptions: DataSourceOptions = {
+      new DataSourceOptions(options.asJava)
+    }
+
+    def asCaseInsensitiveMap: CaseInsensitiveStringMap = {
+      val map = CaseInsensitiveStringMap.empty()
+      map.putAll(options.asJava)
+      map
+    }
+
+    def table: Option[TableIdentifier] = {
+      val map = asCaseInsensitiveMap
+      Option(map.get(DataSourceOptions.TABLE_KEY))
+        .map(TableIdentifier(_, Option(map.get(DataSourceOptions.DATABASE_KEY))))
+    }
+
+    def paths: Array[String] = {
+      asDataSourceOptions.paths()
+    }
+  }
+}
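For illustration only: the OptionsHelper implicits are what DataFrameReader and DataFrameWriter use to turn a plain options map into a table identifier. Since DataSourceV2Implicits is private[sql], the sketch below only runs from code inside the org.apache.spark.sql package (for example a test), and the option values are made up.

    import org.apache.spark.sql.catalyst.TableIdentifier
    import org.apache.spark.sql.sources.v2.DataSourceV2Implicits._

    val options = Map(
      "Table" -> "events",      // keys are looked up case-insensitively
      "database" -> "db1")

    assert(options.table == Some(TableIdentifier("events", Some("db1"))))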