From 2ffed5fdc518c5832ca917722fb197c11f630d6d Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Mon, 15 Oct 2018 18:30:31 +0800
Subject: [PATCH 1/2] Address comments at 22379

---
 R/pkg/R/functions.R                           | 11 ++++++--
 R/pkg/tests/fulltests/test_sparkSQL.R         |  2 ++
 python/pyspark/sql/functions.py               | 16 +++++++++++-
 ...SVUtils.scala => CSVExpressionUtils.scala} |  2 +-
 .../sql/catalyst/csv/CSVHeaderChecker.scala   |  2 +-
 .../spark/sql/catalyst/csv/CSVOptions.scala   |  2 +-
 .../sql/catalyst/csv/UnivocityParser.scala    |  2 +-
 .../sql/catalyst/csv/CSVUtilsSuite.scala      | 26 +++++++++----------
 .../datasources/csv/CSVDataSource.scala       |  3 +--
 .../execution/datasources/csv/CSVUtils.scala  |  4 +++
 .../org/apache/spark/sql/functions.scala      |  4 +--
 .../apache/spark/sql/CsvFunctionsSuite.scala  |  2 +-
 12 files changed, 51 insertions(+), 25 deletions(-)
 rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/{CSVUtils.scala => CSVExpressionUtils.scala} (99%)

diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 72ae3771211d..d2ca1d6c00bb 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -2223,12 +2223,19 @@ setMethod("from_json", signature(x = "Column", schema = "characterOrstructType")
 #' schema <- "city STRING, year INT"
 #' head(select(df, from_csv(df$csv, schema)))}
 #' @note from_csv since 3.0.0
-setMethod("from_csv", signature(x = "Column", schema = "character"),
+setMethod("from_csv", signature(x = "Column", schema = "characterOrColumn"),
           function(x, schema, ...) {
+            if (class(schema) == "Column") {
+              jschema <- schema@jc
+            } else if (is.character(schema)) {
+              jschema <- callJStatic("org.apache.spark.sql.functions", "lit", schema)
+            } else {
+              stop("schema argument should be a column or character")
+            }
             options <- varargsToStrEnv(...)
             jc <- callJStatic("org.apache.spark.sql.functions",
                               "from_csv",
-                              x@jc, schema, options)
+                              x@jc, jschema, options)
             column(jc)
           })
 
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index 9c790acf7d3b..5ad5d78d3ed1 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -1651,6 +1651,8 @@ test_that("column functions", {
   df <- as.DataFrame(list(list("col" = "1")))
   c <- collect(select(df, alias(from_csv(df$col, "a INT"), "csv")))
   expect_equal(c[[1]][[1]]$a, 1)
+  c <- collect(select(df, alias(from_csv(df$col, lit("a INT")), "csv")))
+  expect_equal(c[[1]][[1]]$a, 1)
 
   # Test to_json(), from_json()
   df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people")
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index b86fd32bca19..9a0ff159f189 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -25,9 +25,12 @@
 if sys.version < "3":
     from itertools import imap as map
 
+if sys.version >= '3':
+    basestring = str
+
 from pyspark import since, SparkContext
 from pyspark.rdd import ignore_unicode_prefix, PythonEvalType
-from pyspark.sql.column import Column, _to_java_column, _to_seq
+from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.types import StringType, DataType
 # Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
@@ -2693,9 +2696,20 @@ def from_csv(col, schema, options={}):
     >>> df = spark.createDataFrame(data, ("key", "value"))
     >>> df.select(from_csv(df.value, "a INT").alias("csv")).collect()
     [Row(csv=Row(a=1))]
+    >>> data = [(1, '1')]
+    >>> df = spark.createDataFrame(data, ("key", "value"))
+    >>> df.select(from_csv(df.value, lit("a INT")).alias("csv")).collect()
+    [Row(csv=Row(a=1))]
     """
     sc = SparkContext._active_spark_context
+    if isinstance(schema, basestring):
+        schema = _create_column_from_literal(schema)
+    elif isinstance(schema, Column):
+        schema = _to_java_column(schema)
+    else:
+        raise TypeError("schema argument should be a column or string")
+
     jc = sc._jvm.functions.from_csv(_to_java_column(col), schema, options)
     return Column(jc)
 
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExpressionUtils.scala
similarity index 99%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExpressionUtils.scala
index 109325089ff8..22fd5b041b49 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExpressionUtils.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.csv
 
-object CSVUtils {
+object CSVExpressionUtils {
   /**
    * Filter ignorable rows for CSV iterator (lines empty and starting with `comment`).
    * This is currently being used in CSV reading path and CSV schema inference.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala
index b1b010de9f2f..4665537dcdb7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVHeaderChecker.scala
@@ -123,7 +123,7 @@ class CSVHeaderChecker(
     // Note: if there are only comments in the first block, the header would probably
     // be not extracted.
     if (options.headerFlag && isStartOfFile) {
-      CSVUtils.extractHeader(lines, options).foreach { header =>
+      CSVExpressionUtils.extractHeader(lines, options).foreach { header =>
         checkHeaderColumnNames(tokenizer.parseLine(header))
       }
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
index 1f39b20bb2f5..957238ff881d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
@@ -83,7 +83,7 @@ class CSVOptions(
     }
   }
 
-  val delimiter = CSVUtils.toChar(
+  val delimiter = CSVExpressionUtils.toChar(
     parameters.getOrElse("sep", parameters.getOrElse("delimiter", ",")))
   val parseMode: ParseMode =
     parameters.get("mode").map(ParseMode.fromString).getOrElse(PermissiveMode)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
index f0890e0d36e8..e61771ae714e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
@@ -338,7 +338,7 @@ private[sql] object UnivocityParser {
 
     val options = parser.options
 
-    val filteredLines: Iterator[String] = CSVUtils.filterCommentAndEmpty(lines, options)
+    val filteredLines: Iterator[String] = CSVExpressionUtils.filterCommentAndEmpty(lines, options)
 
     val safeParser = new FailureSafeParser[String](
       input => Seq(parser.parse(input)),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVUtilsSuite.scala
index dde46a97e673..0d4f1f6eaa67 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVUtilsSuite.scala
@@ -19,42 +19,42 @@ package org.apache.spark.sql.catalyst.csv
 
 import org.apache.spark.SparkFunSuite
 
-class CSVUtilsSuite extends SparkFunSuite {
+class CSVExpressionUtilsSuite extends SparkFunSuite {
   test("Can parse escaped characters") {
-    assert(CSVUtils.toChar("""\t""") === '\t')
-    assert(CSVUtils.toChar("""\r""") === '\r')
-    assert(CSVUtils.toChar("""\b""") === '\b')
-    assert(CSVUtils.toChar("""\f""") === '\f')
-    assert(CSVUtils.toChar("""\"""") === '\"')
-    assert(CSVUtils.toChar("""\'""") === '\'')
-    assert(CSVUtils.toChar("""\u0000""") === '\u0000')
-    assert(CSVUtils.toChar("""\\""") === '\\')
+    assert(CSVExpressionUtils.toChar("""\t""") === '\t')
+    assert(CSVExpressionUtils.toChar("""\r""") === '\r')
+    assert(CSVExpressionUtils.toChar("""\b""") === '\b')
+    assert(CSVExpressionUtils.toChar("""\f""") === '\f')
+    assert(CSVExpressionUtils.toChar("""\"""") === '\"')
+    assert(CSVExpressionUtils.toChar("""\'""") === '\'')
+    assert(CSVExpressionUtils.toChar("""\u0000""") === '\u0000')
+    assert(CSVExpressionUtils.toChar("""\\""") === '\\')
   }
 
   test("Does not accept delimiter larger than one character") {
     val exception = intercept[IllegalArgumentException]{
-      CSVUtils.toChar("ab")
+      CSVExpressionUtils.toChar("ab")
     }
     assert(exception.getMessage.contains("cannot be more than one character"))
   }
 
   test("Throws exception for unsupported escaped characters") {
     val exception = intercept[IllegalArgumentException]{
-      CSVUtils.toChar("""\1""")
+      CSVExpressionUtils.toChar("""\1""")
     }
     assert(exception.getMessage.contains("Unsupported special character for delimiter"))
   }
 
   test("string with one backward slash is prohibited") {
     val exception = intercept[IllegalArgumentException]{
-      CSVUtils.toChar("""\""")
+      CSVExpressionUtils.toChar("""\""")
     }
     assert(exception.getMessage.contains("Single backslash is prohibited"))
   }
 
   test("output proper error message for empty string") {
     val exception = intercept[IllegalArgumentException]{
-      CSVUtils.toChar("")
+      CSVExpressionUtils.toChar("")
     }
     assert(exception.getMessage.contains("Delimiter cannot be empty string"))
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index b133bb5a57b2..9e7b45db9f28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -35,7 +35,6 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD}
 import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser}
-import org.apache.spark.sql.catalyst.csv.CSVUtils.filterCommentAndEmpty
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.types.StructType
@@ -130,7 +129,7 @@ object TextInputCSVDataSource extends CSVDataSource {
       val header = CSVUtils.makeSafeHeader(firstRow, caseSensitive, parsedOptions)
       val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions)
       val tokenRDD = sampled.rdd.mapPartitions { iter =>
-        val filteredLines = filterCommentAndEmpty(iter, parsedOptions)
+        val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
         val linesWithoutHeader =
           CSVUtils.filterHeaderLine(filteredLines, maybeFirstLine.get, parsedOptions)
         val parser = new CsvParser(parsedOptions.asParserSettings)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
index 689555f53eb2..8b6945968181 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.csv
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.catalyst.csv.CSVExpressionUtils
 import org.apache.spark.sql.catalyst.csv.CSVOptions
 import org.apache.spark.sql.functions._
 
@@ -125,4 +126,7 @@ object CSVUtils {
       csv.sample(withReplacement = false, options.samplingRatio, 1)
     }
   }
+
+  def filterCommentAndEmpty(iter: Iterator[String], options: CSVOptions): Iterator[String] =
+    CSVExpressionUtils.filterCommentAndEmpty(iter, options)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 466db861ddf3..8def9967cffb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3882,8 +3882,8 @@ object functions {
    * @group collection_funcs
    * @since 3.0.0
    */
-  def from_csv(e: Column, schema: String, options: java.util.Map[String, String]): Column = {
-    withExpr(new CsvToStructs(e.expr, lit(schema).expr, options.asScala.toMap))
+  def from_csv(e: Column, schema: Column, options: java.util.Map[String, String]): Column = {
+    withExpr(new CsvToStructs(e.expr, schema.expr, options.asScala.toMap))
   }
 
   // scalastyle:off line.size.limit
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
index 6ccd06be26a5..38a2143d6d0f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
@@ -31,7 +31,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSQLContext {
     val schema = "a int"
 
     checkAnswer(
-      df.select(from_csv($"value", schema, Map[String, String]().asJava)),
+      df.select(from_csv($"value", lit(schema), Map[String, String]().asJava)),
       Row(Row(1)) :: Nil)
   }
 

From a32bbcb44ae2ca9f5d329e96b7c33d37ed3208a0 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Mon, 15 Oct 2018 18:34:17 +0800
Subject: [PATCH 2/2] nit

---
 python/pyspark/sql/functions.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 9a0ff159f189..32d7f02f6188 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2696,7 +2696,6 @@ def from_csv(col, schema, options={}):
     >>> df = spark.createDataFrame(data, ("key", "value"))
     >>> df.select(from_csv(df.value, "a INT").alias("csv")).collect()
     [Row(csv=Row(a=1))]
-    >>> data = [(1, '1')]
     >>> df = spark.createDataFrame(data, ("key", "value"))
     >>> df.select(from_csv(df.value, lit("a INT")).alias("csv")).collect()
     [Row(csv=Row(a=1))]
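
For reference, a minimal Scala usage sketch of the Column-based from_csv overload that this
patch introduces, mirroring the updated CsvFunctionsSuite test. The session setup, sample data,
and object name are illustrative assumptions and are not part of the patch itself.

// Hedged sketch: calling from_csv with the schema passed as a Column (e.g. a literal DDL string).
import java.util.Collections

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{from_csv, lit}

object FromCsvColumnSchemaSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("from_csv with a Column schema")
      .getOrCreate()
    import spark.implicits._

    // A single CSV-formatted string column named "value".
    val df = Seq("1,Seoul").toDF("value")

    // The schema argument is now a Column, so a DDL string is wrapped with lit(),
    // and options are passed as a java.util.Map per the new signature.
    val parsed = df.select(
      from_csv($"value", lit("a INT, city STRING"), Collections.emptyMap[String, String]()))

    parsed.show(truncate = false)
    spark.stop()
  }
}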