From d5145dce5b0838469ad8c52cac5eba4d50e36421 Mon Sep 17 00:00:00 2001
From: Adrian Olosutean
Date: Fri, 7 Jan 2022 15:04:03 +0100
Subject: [PATCH 1/3] #17 DataFrameImplicits

---
 README.md                                   |  30 ++-
 build.sbt                                   |   1 +
 .../spark/commons/error/ErrorMessage.scala  |  43 ++++
 .../implicits/DataFrameImplicits.scala      | 143 +++++++++++
 .../commons/{ => schema}/SchemaUtils.scala  |  26 +-
 .../implicits/DataFrameImplicitsSuite.scala | 237 ++++++++++++++++++
 .../{ => schema}/SchemaUtilsSpec.scala      |   2 +-
 7 files changed, 479 insertions(+), 3 deletions(-)
 create mode 100644 src/main/scala/za/co/absa/spark/commons/error/ErrorMessage.scala
 create mode 100644 src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala
 rename src/main/scala/za/co/absa/spark/commons/{ => schema}/SchemaUtils.scala (92%)
 create mode 100644 src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala
 rename src/test/scala/za/co/absa/spark/commons/{ => schema}/SchemaUtilsSpec.scala (99%)

diff --git a/README.md b/README.md
index a2267190..2c411429 100644
--- a/README.md
+++ b/README.md
@@ -50,4 +50,32 @@ select to order and positionally filter columns of a DataFrame
 
    ```scala
    SchemaUtils.alignSchema(dataFrameToBeAligned, modelSchema)
-   ```
\ No newline at end of file
+   ```
+5. Getting a column with a unique name in case a schema is provided
+
+   ```scala
+   SchemaUtils.getUniqueName(prefix, modelSchema)
+   ```
+
+### DataFrameImplicits
+_DataFrameImplicits_ provides methods for transformations on DataFrames.
+
+1. Getting the string of the shown data of a dataframe
+
+   ```scala
+   df.dataAsString()
+
+   df.dataAsString(truncate)
+
+   df.dataAsString(numRows, truncate)
+
+   df.dataAsString(numRows, truncateNumber)
+
+   df.dataAsString(numRows, truncate, vertical)
+   ```
+
+2. Adds a column to a dataframe if it does not exist. If it exists, it will add an error to the error column
+
+   ```scala
+   df.withColumnIfDoesNotExist(colName, colExpression)
+   ```
\ No newline at end of file
diff --git a/build.sbt b/build.sbt
index 2e52284a..28d91d91 100644
--- a/build.sbt
+++ b/build.sbt
@@ -29,6 +29,7 @@ libraryDependencies ++= List(
   "org.apache.spark" %% "spark-sql" % sparkVersion % "provided",
   "za.co.absa.commons" %% "commons" % "1.0.0",
   "za.co.absa" %% "spark-hofs" % "0.4.0",
+  "za.co.absa" %% "spark-hats" % "0.2.2",
   "org.scala-lang" % "scala-compiler" % scalaVersion.value,
   "org.scalatest" %% "scalatest" % "3.1.0" % Test,
   "org.scalatest" %% "scalatest-flatspec" % "3.2.0" % Test,
diff --git a/src/main/scala/za/co/absa/spark/commons/error/ErrorMessage.scala b/src/main/scala/za/co/absa/spark/commons/error/ErrorMessage.scala
new file mode 100644
index 00000000..c5fda072
--- /dev/null
+++ b/src/main/scala/za/co/absa/spark/commons/error/ErrorMessage.scala
@@ -0,0 +1,43 @@
+/*
+ * Copyright 2021 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.spark.commons.error
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Case class to represent an error message
+ *
+ * @param errType - Type or source of the error
+ * @param errCode - Internal error code
+ * @param errMsg - Textual description of the error
+ * @param errCol - The name of the column where the error occurred
+ * @param rawValues - Sequence of raw values (which are the potential culprits of the error)
+ * @param mappings - Sequence of Mappings, i.e. Mapping Table Column -> Equivalent Mapped Dataset column
+ */
+case class ErrorMessage(errType: String, errCode: String, errMsg: String, errCol: String, rawValues: Seq[String], mappings: Seq[Mapping] = Seq())
+case class Mapping(mappingTableColumn: String, mappedDatasetColumn: String)
+
+object ErrorMessage {
+  val errorColumnName = "errCol"
+
+  def errorColSchema(implicit spark: SparkSession): StructType = {
+    import spark.implicits._
+    spark.emptyDataset[ErrorMessage].schema
+  }
+}
+
diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala
new file mode 100644
index 00000000..a99aeb8c
--- /dev/null
+++ b/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala
@@ -0,0 +1,143 @@
+/*
+ * Copyright 2021 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.spark.commons.implicits
+
+import java.io.ByteArrayOutputStream
+
+import org.apache.log4j.{LogManager, Logger}
+import org.apache.spark.sql.functions.{array, callUDF, col, lit, when}
+import org.apache.spark.sql.{Column, DataFrame, SparkSession}
+import za.co.absa.spark.commons.error.ErrorMessage
+import za.co.absa.spark.commons.schema.SchemaUtils
+import za.co.absa.spark.hats.transformations.NestedArrayTransformations
+
+import scala.collection.mutable
+
+object DataFrameImplicits {
+
+  private val log: Logger = LogManager.getLogger(this.getClass)
+
+  private val overWriteErrorFunction = "overWriteErr"
+  private val overWriteErrorType = "overWriteError"
+  private val overWriteErrorCode = "E000OW"
+
+  implicit class DataFrameEnhancements(val df: DataFrame) {
+
+    private def gatherData(showFnc: () => Unit): String = {
+      val outCapture = new ByteArrayOutputStream
+      Console.withOut(outCapture) {
+        showFnc()
+      }
+      val dfData = new String(outCapture.toByteArray).replace("\r\n", "\n")
+      dfData
+    }
+
+    def dataAsString(): String = {
+      val showFnc: () => Unit = df.show
+      gatherData(showFnc)
+    }
+
+    def dataAsString(truncate: Boolean): String = {
+      val showFnc: () => Unit = () => df.show(truncate)
+      gatherData(showFnc)
+    }
+
+    def dataAsString(numRows: Int, truncate: Boolean): String = {
+      val showFnc: () => Unit = () => df.show(numRows, truncate)
+      gatherData(showFnc)
+    }
+
+    def dataAsString(numRows: Int, truncate: Int): String = {
+      val showFnc: () => Unit = () => df.show(numRows, truncate)
+      gatherData(showFnc)
+    }
+
+    def dataAsString(numRows: Int, truncate: Int, vertical: Boolean): String = {
+      val showFnc: () => Unit = () => df.show(numRows, truncate, vertical)
+      gatherData(showFnc)
+    }
+
+    /**
+     * Adds a column to a dataframe if it does not exist
+     *
+     * @param colName A column to add if it does not exist already
+     * @param colExpr An expression for the column to add
+     * @return a new dataframe with the new column
+     */
+    def withColumnIfDoesNotExist(colName: String, colExpr: Column): DataFrame = {
+      if (df.schema.exists(field => field.name.equalsIgnoreCase(colName))) {
+        log.warn(s"Column '$colName' already exists. The content of the column will be overwritten.")
+        overwriteWithErrorColumn(df, colName, colExpr)
+      } else {
+        df.withColumn(colName, colExpr)
+      }
+    }
+
+    /**
+     * Overwrites a column with a value provided by an expression.
+     * If the value in the column does not match the one provided by the expression, an error will be
+     * added to the error column.
+     *
+     * @param df A dataframe
+     * @param colName A column to be overwritten
+     * @param colExpr An expression for the value to write
+     * @return a new dataframe with the value of the column being overwritten
+     */
+    private def overwriteWithErrorColumn(df: DataFrame, colName: String, colExpr: Column): DataFrame = {
+      implicit val spark: SparkSession = df.sparkSession
+      spark.udf.register(overWriteErrorFunction, { (errCol: String, rawValue: String) =>
+        ErrorMessage(
+          errType = overWriteErrorType,
+          errCode = overWriteErrorCode,
+          errMsg = "Special column value has changed",
+          errCol = errCol,
+          rawValues = Seq(rawValue))
+      })
+      spark.udf.register("arrayDistinctErrors", // this UDF is registered for the sake of the _spark-hats_ library
+        (arr: mutable.WrappedArray[ErrorMessage]) =>
+          if (arr != null) {
+            arr.distinct.filter((a: AnyRef) => a != null)
+          } else {
+            Seq[ErrorMessage]()
+          }
+      )
+
+      val tmpColumn = SchemaUtils.getUniqueName("tmpColumn", Some(df.schema))
+      val tmpErrColumn = SchemaUtils.getUniqueName("tmpErrColumn", Some(df.schema))
+      val litErrUdfCall = callUDF(overWriteErrorFunction, lit(colName), col(tmpColumn))
+
+      // Rename the original column to a temporary name. We need it for comparison.
+      val dfWithColRenamed = df.withColumnRenamed(colName, tmpColumn)
+
+      // Add new column with the intended value
+      val dfWithIntendedColumn = dfWithColRenamed.withColumn(colName, colExpr)
+
+      // Add a temporary error column containing errors if the original value does not match the intended one
+      val dfWithErrorColumn = dfWithIntendedColumn
+        .withColumn(tmpErrColumn, array(when(col(tmpColumn) =!= colExpr, litErrUdfCall).otherwise(null))) // scalastyle:ignore null
+
+      // Gather all errors in errCol
+      val dfWithAggregatedErrColumn = NestedArrayTransformations
+        .gatherErrors(dfWithErrorColumn, tmpErrColumn, ErrorMessage.errorColumnName)
+
+      // Drop the temporary column
+      dfWithAggregatedErrColumn.drop(tmpColumn)
+    }
+
+  }
+
+}
diff --git a/src/main/scala/za/co/absa/spark/commons/SchemaUtils.scala b/src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala
similarity index 92%
rename from src/main/scala/za/co/absa/spark/commons/SchemaUtils.scala
rename to src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala
index c3300a41..cae63b2a 100644
--- a/src/main/scala/za/co/absa/spark/commons/SchemaUtils.scala
+++ b/src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala
@@ -14,13 +14,15 @@
  * limitations under the License.
  */
 
-package za.co.absa.spark.commons
+package za.co.absa.spark.commons.schema
 
 import org.apache.spark.sql.functions.{col, struct}
 import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType}
 import org.apache.spark.sql.{Column, DataFrame}
 import za.co.absa.spark.commons.adapters.HofsAdapter
 
+import scala.util.Random
+
 object SchemaUtils extends HofsAdapter {
   /**
    * Compares 2 array fields of a dataframe schema.
@@ -230,4 +232,26 @@ object SchemaUtils extends HofsAdapter {
   private def getMapOfFields(schema: StructType): Map[String, StructField] = {
     schema.map(field => field.name.toLowerCase() -> field).toMap
   }
+
+  /**
+   * Generate a unique column name
+   *
+   * @param prefix A prefix to use for the column name
+   * @param schema An optional schema to validate if the column already exists (a very low probability)
+   * @return A name that can be used as a unique column name
+   */
+  def getUniqueName(prefix: String, schema: Option[StructType]): String = {
+    schema match {
+      case None =>
+        s"${prefix}_${Random.nextLong().abs}"
+      case Some(sch) =>
+        var exists = true
+        var columnName = ""
+        while (exists) {
+          columnName = s"${prefix}_${Random.nextLong().abs}"
+          exists = sch.fields.exists(_.name.compareToIgnoreCase(columnName) == 0)
+        }
+        columnName
+    }
+  }
 }
diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala b/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala
new file mode 100644
index 00000000..954c60b1
--- /dev/null
+++ b/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala
@@ -0,0 +1,237 @@
+/*
+ * Copyright 2021 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.spark.commons.implicits
+
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.lit
+import org.scalatest.funsuite.AnyFunSuite
+import za.co.absa.spark.commons.test.SparkTestBase
+import za.co.absa.spark.commons.implicits.DataFrameImplicits.DataFrameEnhancements
+
+class DataFrameImplicitsSuite extends AnyFunSuite with SparkTestBase {
+  import spark.implicits._
+
+  private val columnName = "data"
+  private val inputDataSeq = Seq(
+    "0123456789012345678901234",
+    "a",
+    "b",
+    "c",
+    "d",
+    "e",
+    "f",
+    "g",
+    "h",
+    "i",
+    "j",
+    "k",
+    "l",
+    "m",
+    "n",
+    "o",
+    "p",
+    "q",
+    "r",
+    "s",
+    "t",
+    "u",
+    "v",
+    "w",
+    "x",
+    "y",
+    "z"
+  )
+  private val inputData = inputDataSeq.toDF(columnName)
+
+  import za.co.absa.spark.commons.implicits.DataFrameImplicits.DataFrameEnhancements
+
+  private def getDummyDataFrame: DataFrame = {
+    import spark.implicits._
+
+    Seq(1, 1, 1, 2, 1).toDF("value")
+  }
+
+  private def cellText(text: String, width: Int, leftAlign: Boolean): String = {
+    val pad = " " * (width - text.length)
+    if (leftAlign) {
+      text + pad
+    } else {
+      pad + text
+    }
+  }
+
+  private def line(width: Int): String = {
+    "+" + "-" * width + "+"
+  }
+
+  private def header(width: Int, leftAlign: Boolean): String = {
+    val lineStr = line(width)
+    val title = cellText(columnName, width, leftAlign)
+    s"$lineStr\n|$title|\n$lineStr"
+  }
+
+  private def cell(text: String, width: Int, leftAlign: Boolean): String = {
+    val inner = if (text.length > width) {
+      text.substring(0, width - 3) + "..."
+    } else {
+      cellText(text, width, leftAlign)
+    }
+    s"|$inner|"
+  }
+
+  private def inputDataToString(width: Int, leftAlign: Boolean, limit: Option[Int] = Option(20)): String = {
+    val (extraLine, seq) = limit match {
+      case Some(n) =>
+        val line = if (inputDataSeq.length > n) {
+          s"only showing top $n rows\n"
+        } else {
+          ""
+        }
+        (line, inputDataSeq.take(n))
+      case None =>
+        ("", inputDataSeq)
+    }
+    seq.foldLeft(header(width, leftAlign)) { (acc, item) =>
+      acc + "\n" + cell(item, width, leftAlign)
+    } + "\n" + line(width) + s"\n$extraLine\n"
+  }
+
+  test("Like show()") {
+    val result = inputData.dataAsString()
+    val leftAlign = false
+    val cellWidth = 20
+    val expected = inputDataToString(cellWidth, leftAlign)
+
+    assert(result == expected)
+  }
+
+  test("Like show(false)") {
+    val result = inputData.dataAsString(false)
+    val leftAlign = true
+    val cellWidth = 25
+    val expected = inputDataToString(cellWidth, leftAlign)
+
+    assert(result == expected)
+  }
+
+  test("Like show(3, true)") {
+    val result = inputData.dataAsString(3, true)
+    val leftAlign = false
+    val cellWidth = 20
+    val expected = inputDataToString(cellWidth, leftAlign, Option(3))
+
+    assert(result == expected)
+  }
+
+  test("Like show(30, false)") {
+    val result = inputData.dataAsString(30, false)
+    val leftAlign = true
+    val cellWidth = 25
+    val expected = inputDataToString(cellWidth, leftAlign, Option(30))
+
+    assert(result == expected)
+  }
+
+
+  test("Like show(10, 10)") {
+    val result = inputData.dataAsString(10, 10)
+    val leftAlign = false
+    val cellWidth = 10
+    val expected = inputDataToString(cellWidth, leftAlign, Option(10))
+
+    assert(result == expected)
+  }
+
+  test("Like show(50, 50, false)") {
+    val result = inputData.dataAsString(50, 50, false)
+    val leftAlign = false
+    val cellWidth = 25
+    val expected = inputDataToString(cellWidth, leftAlign, Option(50))
+
+    assert(result == expected)
+  }
+
+  test("Test withColumnIfNotExist() when the column does not exist") {
+    val expectedOutput =
+      """+-----+---+
+        ||value|foo|
+        |+-----+---+
+        ||1    |1  |
+        ||1    |1  |
+        ||1    |1  |
+        ||2    |1  |
+        ||1    |1  |
+        |+-----+---+
+        |
+        |""".stripMargin.replace("\r\n", "\n")
+
+    val dfIn = getDummyDataFrame
+    val dfOut = dfIn.withColumnIfDoesNotExist("foo", lit(1))
+    val actualOutput = dfOut.dataAsString(truncate = false)
+
+    assert(dfOut.schema.length == 2)
+    assert(dfOut.schema.head.name == "value")
+    assert(dfOut.schema(1).name == "foo")
+    assert(actualOutput == expectedOutput)
+  }
+
+  test("Test withColumnIfNotExist() when the column exists") {
+    val expectedOutput =
+      """+-----+----------------------------------------------------------------------------+
+        ||value|errCol                                                                      |
+        |+-----+----------------------------------------------------------------------------+
+        ||1    |[]                                                                          |
+        ||1    |[]                                                                          |
+        ||1    |[]                                                                          |
+        ||1    |[[overWriteError, E000OW, Special column value has changed, value, [2], []]]|
+        ||1    |[]                                                                          |
+        |+-----+----------------------------------------------------------------------------+
+        |
+        |""".stripMargin.replace("\r\n", "\n")
+
+    val dfIn = getDummyDataFrame
+    val dfOut = dfIn.withColumnIfDoesNotExist("value", lit(1))
+    val actualOutput = dfOut.dataAsString(truncate = false)
+
+    assert(dfIn.schema.length == 1)
+    assert(dfIn.schema.head.name == "value")
+    assert(actualOutput == expectedOutput)
+  }
+
+  test("Test withColumnIfNotExist() when the column exists, but has a different case") {
+    val expectedOutput =
+      """+-----+----------------------------------------------------------------------------+
+        ||vAlUe|errCol                                                                      |
+        |+-----+----------------------------------------------------------------------------+
+        ||1    |[]                                                                          |
+        ||1    |[]                                                                          |
+        ||1    |[]                                                                          |
+        ||1    |[[overWriteError, E000OW, Special column value has changed, vAlUe, [2], []]]|
+        ||1    |[]                                                                          |
+        |+-----+----------------------------------------------------------------------------+
+        |
+        |""".stripMargin.replace("\r\n", "\n")
+
+    val dfIn = getDummyDataFrame
+    val dfOut = dfIn.withColumnIfDoesNotExist("vAlUe", lit(1))
+    val actualOutput = dfOut.dataAsString(truncate = false)
+
+    assert(dfIn.schema.length == 1)
+    assert(dfIn.schema.head.name == "value")
+    assert(actualOutput == expectedOutput)
+  }
+}
diff --git a/src/test/scala/za/co/absa/spark/commons/SchemaUtilsSpec.scala b/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilsSpec.scala
similarity index 99%
rename from src/test/scala/za/co/absa/spark/commons/SchemaUtilsSpec.scala
rename to src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilsSpec.scala
index 9cf1cf7c..771e19ad 100644
--- a/src/test/scala/za/co/absa/spark/commons/SchemaUtilsSpec.scala
+++ b/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilsSpec.scala
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-package za.co.absa.spark.commons
+package za.co.absa.spark.commons.schema
 
 import org.apache.spark.sql.AnalysisException
 import org.scalatest.BeforeAndAfterAll

From 2398eb7499cb267e2be9acdccf60e5fe30dfcb50 Mon Sep 17 00:00:00 2001
From: Adrian Olosutean
Date: Mon, 10 Jan 2022 12:06:25 +0100
Subject: [PATCH 2/3] #17 fixed tests for Spark versions

---
 .../implicits/DataFrameImplicitsSuite.scala | 55 ++++++++++++++++++-
 1 file changed, 52 insertions(+), 3 deletions(-)

diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala b/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala
index 954c60b1..5e7e10a1 100644
--- a/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala
+++ b/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala
@@ -20,7 +20,6 @@ import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions.lit
 import org.scalatest.funsuite.AnyFunSuite
 import za.co.absa.spark.commons.test.SparkTestBase
-import za.co.absa.spark.commons.implicits.DataFrameImplicits.DataFrameEnhancements
 
 class DataFrameImplicitsSuite extends AnyFunSuite with SparkTestBase {
   import spark.implicits._
@@ -189,7 +188,8 @@ class DataFrameImplicitsSuite extends AnyFunSuite with SparkTestBase {
     assert(actualOutput == expectedOutput)
   }
 
-  test("Test withColumnIfNotExist() when the column exists") {
+  test("Test withColumnIfNotExist() when the column exists, Spark 2.X") {
+    assume(!sys.props.getOrElse("SPARK_VERSION", "").startsWith("3."))
     val expectedOutput =
       """+-----+----------------------------------------------------------------------------+
         ||value|errCol                                                                      |
@@ -212,7 +212,8 @@ class DataFrameImplicitsSuite extends AnyFunSuite with SparkTestBase {
     assert(actualOutput == expectedOutput)
   }
 
-  test("Test withColumnIfNotExist() when the column exists, but has a different case") {
+  test("Test withColumnIfNotExist() when the column exists, but has a different case, Spark 2.X") {
+    assume(!sys.props.getOrElse("SPARK_VERSION", "").startsWith("3."))
     val expectedOutput =
       """+-----+----------------------------------------------------------------------------+
         ||vAlUe|errCol                                                                      |
@@ -234,4 +235,52 @@ class DataFrameImplicitsSuite extends AnyFunSuite with SparkTestBase {
     assert(dfIn.schema.head.name == "value")
     assert(actualOutput == expectedOutput)
   }
+
+  test("Test withColumnIfNotExist() when the column exists, Spark 3.X") {
+    assume(!sys.props.getOrElse("SPARK_VERSION", "").startsWith("2."))
+    val expectedOutput =
+      """+-----+----------------------------------------------------------------------------+
+        ||value|errCol                                                                      |
+        |+-----+----------------------------------------------------------------------------+
+        ||1    |[]                                                                          |
+        ||1    |[]                                                                          |
+        ||1    |[]                                                                          |
+        ||1    |[{overWriteError, E000OW, Special column value has changed, value, [2], []}]|
+        ||1    |[]                                                                          |
+        |+-----+----------------------------------------------------------------------------+
+        |
+        |""".stripMargin.replace("\r\n", "\n")
+
+    val dfIn = getDummyDataFrame
+    val dfOut = dfIn.withColumnIfDoesNotExist("value", lit(1))
+    val actualOutput = dfOut.dataAsString(truncate = false)
+
+    assert(dfIn.schema.length == 1)
+    assert(dfIn.schema.head.name == "value")
+    assert(actualOutput == expectedOutput)
+  }
+
+  test("Test withColumnIfNotExist() when the column exists, but has a different case, Spark 3.X") {
+    assume(!sys.props.getOrElse("SPARK_VERSION", "").startsWith("2."))
+    val expectedOutput =
+      """+-----+----------------------------------------------------------------------------+
+        ||vAlUe|errCol                                                                      |
+        |+-----+----------------------------------------------------------------------------+
+        ||1    |[]                                                                          |
+        ||1    |[]                                                                          |
+        ||1    |[]                                                                          |
+        ||1    |[{overWriteError, E000OW, Special column value has changed, vAlUe, [2], []}]|
+        ||1    |[]                                                                          |
+        |+-----+----------------------------------------------------------------------------+
+        |
+        |""".stripMargin.replace("\r\n", "\n")
+
+    val dfIn = getDummyDataFrame
+    val dfOut = dfIn.withColumnIfDoesNotExist("vAlUe", lit(1))
+    val actualOutput = dfOut.dataAsString(truncate = false)
+
+    assert(dfIn.schema.length == 1)
+    assert(dfIn.schema.head.name == "value")
+    assert(actualOutput == expectedOutput)
+  }
 }

From 579ca7cb9462835dbd9d1de029dc13004dcb67f2 Mon Sep 17 00:00:00 2001
From: Adrian Olosutean
Date: Fri, 21 Jan 2022 10:40:55 +0100
Subject: [PATCH 3/3] #17 changes

---
 README.md                                   | 12 +--
 build.sbt                                   |  1 -
 .../spark/commons/error/ErrorMessage.scala  | 43 --------
 .../implicits/DataFrameImplicits.scala      | 73 +-------------
 .../spark/commons/schema/SchemaUtils.scala  | 22 -----
 .../implicits/DataFrameImplicitsSuite.scala | 97 +++++--------
 6 files changed, 31 insertions(+), 217 deletions(-)
 delete mode 100644 src/main/scala/za/co/absa/spark/commons/error/ErrorMessage.scala

diff --git a/README.md b/README.md
index c57e5b8e..d9401a6d 100644
--- a/README.md
+++ b/README.md
@@ -51,12 +51,6 @@ select to order and positionally filter columns of a DataFrame
    ```scala
    SchemaUtils.alignSchema(dataFrameToBeAligned, modelSchema)
    ```
-5. Getting a column with a unique name in case a schema is provided
-
-   ```scala
-   SchemaUtils.getUniqueName(prefix, modelSchema)
-   ```
-
 # Spark Version Guard
 
 A class which checks if the Spark job version is compatible with the Spark Versions supported by the library
@@ -79,7 +73,7 @@ SparkVersionGuard.fromSpark3XCompatibilitySettings.ensureSparkVersionCompatibili
 ### DataFrameImplicits
 _DataFrameImplicits_ provides methods for transformations on DataFrames.
 
-1. Getting the string of the shown data of a dataframe
+1. Getting the data of the dataframe as a string, in a similar fashion to how the `show` function presents it.
 
    ```scala
    df.dataAsString()
@@ -93,8 +87,8 @@ _DataFrameImplicits_ provides methods for transformations on DataFrames.
   df.dataAsString(numRows, truncate, vertical)
    ```
 
-2. Adds a column to a dataframe if it does not exist. If it exists, it will add an error to the error column
+2. Adds a column to a dataframe if it does not exist. If it exists, it will apply the provided function.
 
    ```scala
-   df.withColumnIfDoesNotExist(colName, colExpression)
+   df.withColumnIfDoesNotExist((df: DataFrame, _) => df)(colName, colExpression)
    ```
\ No newline at end of file
diff --git a/build.sbt b/build.sbt
index 28d91d91..2e52284a 100644
--- a/build.sbt
+++ b/build.sbt
@@ -29,7 +29,6 @@ libraryDependencies ++= List(
   "org.apache.spark" %% "spark-sql" % sparkVersion % "provided",
   "za.co.absa.commons" %% "commons" % "1.0.0",
   "za.co.absa" %% "spark-hofs" % "0.4.0",
-  "za.co.absa" %% "spark-hats" % "0.2.2",
   "org.scala-lang" % "scala-compiler" % scalaVersion.value,
   "org.scalatest" %% "scalatest" % "3.1.0" % Test,
   "org.scalatest" %% "scalatest-flatspec" % "3.2.0" % Test,
diff --git a/src/main/scala/za/co/absa/spark/commons/error/ErrorMessage.scala b/src/main/scala/za/co/absa/spark/commons/error/ErrorMessage.scala
deleted file mode 100644
index c5fda072..00000000
--- a/src/main/scala/za/co/absa/spark/commons/error/ErrorMessage.scala
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Copyright 2021 ABSA Group Limited
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package za.co.absa.spark.commons.error
-
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.types.StructType
-
-/**
- * Case class to represent an error message
- *
- * @param errType - Type or source of the error
- * @param errCode - Internal error code
- * @param errMsg - Textual description of the error
- * @param errCol - The name of the column where the error occurred
- * @param rawValues - Sequence of raw values (which are the potential culprits of the error)
- * @param mappings - Sequence of Mappings, i.e. Mapping Table Column -> Equivalent Mapped Dataset column
- */
-case class ErrorMessage(errType: String, errCode: String, errMsg: String, errCol: String, rawValues: Seq[String], mappings: Seq[Mapping] = Seq())
-case class Mapping(mappingTableColumn: String, mappedDatasetColumn: String)
-
-object ErrorMessage {
-  val errorColumnName = "errCol"
-
-  def errorColSchema(implicit spark: SparkSession): StructType = {
-    import spark.implicits._
-    spark.emptyDataset[ErrorMessage].schema
-  }
-}
-
diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala
index a99aeb8c..4905b1d6 100644
--- a/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala
+++ b/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala
@@ -18,23 +18,10 @@ package za.co.absa.spark.commons.implicits
 
 import java.io.ByteArrayOutputStream
 
-import org.apache.log4j.{LogManager, Logger}
-import org.apache.spark.sql.functions.{array, callUDF, col, lit, when}
-import org.apache.spark.sql.{Column, DataFrame, SparkSession}
-import za.co.absa.spark.commons.error.ErrorMessage
-import za.co.absa.spark.commons.schema.SchemaUtils
-import za.co.absa.spark.hats.transformations.NestedArrayTransformations
-
-import scala.collection.mutable
+import org.apache.spark.sql.{Column, DataFrame}
 
 object DataFrameImplicits {
 
-  private val log: Logger = LogManager.getLogger(this.getClass)
-
-  private val overWriteErrorFunction = "overWriteErr"
-  private val overWriteErrorType = "overWriteError"
-  private val overWriteErrorCode = "E000OW"
-
   implicit class DataFrameEnhancements(val df: DataFrame) {
 
     private def gatherData(showFnc: () => Unit): String = {
@@ -75,69 +62,17 @@ object DataFrameImplicits {
      * Adds a column to a dataframe if it does not exist
      *
      * @param colName A column to add if it does not exist already
+     * @param ifExists A function to apply when the column already exists
      * @param colExpr An expression for the column to add
      * @return a new dataframe with the new column
      */
-    def withColumnIfDoesNotExist(colName: String, colExpr: Column): DataFrame = {
+    def withColumnIfDoesNotExist(ifExists: (DataFrame, String) => DataFrame)(colName: String, colExpr: Column): DataFrame = {
       if (df.schema.exists(field => field.name.equalsIgnoreCase(colName))) {
-        log.warn(s"Column '$colName' already exists. The content of the column will be overwritten.")
-        overwriteWithErrorColumn(df, colName, colExpr)
+        ifExists(df, colName)
       } else {
         df.withColumn(colName, colExpr)
       }
     }
-
-    /**
-     * Overwrites a column with a value provided by an expression.
-     * If the value in the column does not match the one provided by the expression, an error will be
-     * added to the error column.
-     *
-     * @param df A dataframe
-     * @param colName A column to be overwritten
-     * @param colExpr An expression for the value to write
-     * @return a new dataframe with the value of the column being overwritten
-     */
-    private def overwriteWithErrorColumn(df: DataFrame, colName: String, colExpr: Column): DataFrame = {
-      implicit val spark: SparkSession = df.sparkSession
-      spark.udf.register(overWriteErrorFunction, { (errCol: String, rawValue: String) =>
-        ErrorMessage(
-          errType = overWriteErrorType,
-          errCode = overWriteErrorCode,
-          errMsg = "Special column value has changed",
-          errCol = errCol,
-          rawValues = Seq(rawValue))
-      })
-      spark.udf.register("arrayDistinctErrors", // this UDF is registered for the sake of the _spark-hats_ library
-        (arr: mutable.WrappedArray[ErrorMessage]) =>
-          if (arr != null) {
-            arr.distinct.filter((a: AnyRef) => a != null)
-          } else {
-            Seq[ErrorMessage]()
-          }
-      )
-
-      val tmpColumn = SchemaUtils.getUniqueName("tmpColumn", Some(df.schema))
-      val tmpErrColumn = SchemaUtils.getUniqueName("tmpErrColumn", Some(df.schema))
-      val litErrUdfCall = callUDF(overWriteErrorFunction, lit(colName), col(tmpColumn))
-
-      // Rename the original column to a temporary name. We need it for comparison.
- val dfWithColRenamed = df.withColumnRenamed(colName, tmpColumn) - - // Add new column with the intended value - val dfWithIntendedColumn = dfWithColRenamed.withColumn(colName, colExpr) - - // Add a temporary error column containing errors if the original value does not match the intended one - val dfWithErrorColumn = dfWithIntendedColumn - .withColumn(tmpErrColumn, array(when(col(tmpColumn) =!= colExpr, litErrUdfCall).otherwise(null))) // scalastyle:ignore null - - // Gather all errors in errCol - val dfWithAggregatedErrColumn = NestedArrayTransformations - .gatherErrors(dfWithErrorColumn, tmpErrColumn, ErrorMessage.errorColumnName) - - // Drop the temporary column - dfWithAggregatedErrColumn.drop(tmpColumn) - } - } } diff --git a/src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala b/src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala index cae63b2a..9de12d61 100644 --- a/src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala +++ b/src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala @@ -232,26 +232,4 @@ object SchemaUtils extends HofsAdapter { private def getMapOfFields(schema: StructType): Map[String, StructField] = { schema.map(field => field.name.toLowerCase() -> field).toMap } - - /** - * Generate a unique column name - * - * @param prefix A prefix to use for the column name - * @param schema An optional schema to validate if the column already exists (a very low probability) - * @return A name that can be used as a unique column name - */ - def getUniqueName(prefix: String, schema: Option[StructType]): String = { - schema match { - case None => - s"${prefix}_${Random.nextLong().abs}" - case Some(sch) => - var exists = true - var columnName = "" - while (exists) { - columnName = s"${prefix}_${Random.nextLong().abs}" - exists = sch.fields.exists(_.name.compareToIgnoreCase(columnName) == 0) - } - columnName - } - } } diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala b/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala index 5e7e10a1..ac5a5780 100644 --- a/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala +++ b/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala @@ -179,7 +179,7 @@ class DataFrameImplicitsSuite extends AnyFunSuite with SparkTestBase { |""".stripMargin.replace("\r\n", "\n") val dfIn = getDummyDataFrame - val dfOut = dfIn.withColumnIfDoesNotExist("foo", lit(1)) + val dfOut = dfIn.withColumnIfDoesNotExist((df: DataFrame, _) => df)("foo", lit(1)) val actualOutput = dfOut.dataAsString(truncate = false) assert(dfOut.schema.length == 2) @@ -188,23 +188,22 @@ class DataFrameImplicitsSuite extends AnyFunSuite with SparkTestBase { assert(actualOutput == expectedOutput) } - test("Test withColumnIfNotExist() when the column exists, Spark 2.X") { - assume(!sys.props.getOrElse("SPARK_VERSION", "").startsWith("3.")) + test("Test withColumnIfNotExist() when the column exists") { val expectedOutput = - """+-----+----------------------------------------------------------------------------+ - ||value|errCol | - |+-----+----------------------------------------------------------------------------+ - ||1 |[] | - ||1 |[] | - ||1 |[] | - ||1 |[[overWriteError, E000OW, Special column value has changed, value, [2], []]]| - ||1 |[] | - |+-----+----------------------------------------------------------------------------+ + """+-----+ + ||value| + |+-----+ + ||1 | + ||1 | + ||1 | + ||2 | + ||1 | + |+-----+ | 
|""".stripMargin.replace("\r\n", "\n") val dfIn = getDummyDataFrame - val dfOut = dfIn.withColumnIfDoesNotExist("value", lit(1)) + val dfOut = dfIn.withColumnIfDoesNotExist((df: DataFrame, _) => df)("value", lit(1)) val actualOutput = dfOut.dataAsString(truncate = false) assert(dfIn.schema.length == 1) @@ -212,71 +211,23 @@ class DataFrameImplicitsSuite extends AnyFunSuite with SparkTestBase { assert(actualOutput == expectedOutput) } - test("Test withColumnIfNotExist() when the column exists, but has a different case, Spark 2.X") { - assume(!sys.props.getOrElse("SPARK_VERSION", "").startsWith("3.")) + test("Test withColumnIfNotExist() when the column exists, but has a different case") { val expectedOutput = - """+-----+----------------------------------------------------------------------------+ - ||vAlUe|errCol | - |+-----+----------------------------------------------------------------------------+ - ||1 |[] | - ||1 |[] | - ||1 |[] | - ||1 |[[overWriteError, E000OW, Special column value has changed, vAlUe, [2], []]]| - ||1 |[] | - |+-----+----------------------------------------------------------------------------+ + """+-----+------+ + ||value|errCol| + |+-----+------+ + ||1 |[] | + ||1 |[] | + ||1 |[] | + ||2 |[] | + ||1 |[] | + |+-----+------+ | |""".stripMargin.replace("\r\n", "\n") val dfIn = getDummyDataFrame - val dfOut = dfIn.withColumnIfDoesNotExist("vAlUe", lit(1)) - val actualOutput = dfOut.dataAsString(truncate = false) - - assert(dfIn.schema.length == 1) - assert(dfIn.schema.head.name == "value") - assert(actualOutput == expectedOutput) - } - - test("Test withColumnIfNotExist() when the column exists, , Spark 3.X") { - assume(!sys.props.getOrElse("SPARK_VERSION", "").startsWith("2.")) - val expectedOutput = - """+-----+----------------------------------------------------------------------------+ - ||value|errCol | - |+-----+----------------------------------------------------------------------------+ - ||1 |[] | - ||1 |[] | - ||1 |[] | - ||1 |[{overWriteError, E000OW, Special column value has changed, value, [2], []}]| - ||1 |[] | - |+-----+----------------------------------------------------------------------------+ - | - |""".stripMargin.replace("\r\n", "\n") - - val dfIn = getDummyDataFrame - val dfOut = dfIn.withColumnIfDoesNotExist("value", lit(1)) - val actualOutput = dfOut.dataAsString(truncate = false) - - assert(dfIn.schema.length == 1) - assert(dfIn.schema.head.name == "value") - assert(actualOutput == expectedOutput) - } - - test("Test withColumnIfNotExist() when the column exists, but has a different case, Spark 3.X") { - assume(!sys.props.getOrElse("SPARK_VERSION", "").startsWith("2.")) - val expectedOutput = - """+-----+----------------------------------------------------------------------------+ - ||vAlUe|errCol | - |+-----+----------------------------------------------------------------------------+ - ||1 |[] | - ||1 |[] | - ||1 |[] | - ||1 |[{overWriteError, E000OW, Special column value has changed, vAlUe, [2], []}]| - ||1 |[] | - |+-----+----------------------------------------------------------------------------+ - | - |""".stripMargin.replace("\r\n", "\n") - - val dfIn = getDummyDataFrame - val dfOut = dfIn.withColumnIfDoesNotExist("vAlUe", lit(1)) + val function: (DataFrame, String) => DataFrame = (df: DataFrame, _) => df.withColumn("errCol", lit(Array.emptyIntArray)) + val dfOut = dfIn.withColumnIfDoesNotExist(function)("vAlUe", lit(1)) val actualOutput = dfOut.dataAsString(truncate = false) assert(dfIn.schema.length == 1)