diff --git a/README.md b/README.md
index 72d24caf..d9401a6d 100644
--- a/README.md
+++ b/README.md
@@ -50,8 +50,7 @@
 select to order and positionally filter columns of a DataFrame
 ```scala
 SchemaUtils.alignSchema(dataFrameToBeAligned, modelSchema)
-  ```
-
+```
 # Spark Version Guard
 A class which checks if the Spark job version is compatible with the Spark Versions supported by the library
 
@@ -69,4 +68,27 @@ SparkVersionGuard.fromSpark2XCompatibilitySettings.ensureSparkVersionCompatibili
 Checking for 3.X versions
 ```scala
 SparkVersionGuard.fromSpark3XCompatibilitySettings.ensureSparkVersionCompatibility(SPARK_VERSION)
-```
\ No newline at end of file
+```
+
+### DataFrameImplicits
+_DataFrameImplicits_ provides extension methods for transformations on DataFrames.
+
+1. Getting the data of a DataFrame as a string, formatted the same way the `show` function presents it.
+
+   ```scala
+   df.dataAsString()
+
+   df.dataAsString(truncate)
+
+   df.dataAsString(numRows, truncate)
+
+   df.dataAsString(numRows, truncateNumber)
+
+   df.dataAsString(numRows, truncate, vertical)
+   ```
+
+2. Adding a column to a DataFrame if it does not exist yet; if it already exists, the provided function is applied instead (a concrete example follows the snippet below).
+
+   ```scala
+   df.withColumnIfDoesNotExist((df: DataFrame, _) => df)(colName, colExpression)
+   ```
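+
+   For instance, to add a column with a default value only when it is missing and keep the DataFrame unchanged otherwise (a minimal sketch; the column name `errCol` and the `lit(0)` default are illustrative, not part of the API):
+
+   ```scala
+   import org.apache.spark.sql.DataFrame
+   import org.apache.spark.sql.functions.lit
+
+   val result = df.withColumnIfDoesNotExist((df: DataFrame, _) => df)("errCol", lit(0))
+   ```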
\ No newline at end of file
diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala
new file mode 100644
index 00000000..4905b1d6
--- /dev/null
+++ b/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala
@@ -0,0 +1,78 @@
+/*
+ * Copyright 2021 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.spark.commons.implicits
+
+import java.io.ByteArrayOutputStream
+
+import org.apache.spark.sql.{Column, DataFrame}
+
+object DataFrameImplicits {
+
+  implicit class DataFrameEnhancements(val df: DataFrame) {
+
+    private def gatherData(showFnc: () => Unit): String = {
+      val outCapture = new ByteArrayOutputStream
+      Console.withOut(outCapture) {
+        showFnc()
+      }
+      val dfData = new String(outCapture.toByteArray).replace("\r\n", "\n")
+      dfData
+    }
+
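+    /**
+     * Returns the data of the DataFrame as a string, rendered exactly as `df.show` would print it to the console.
+     * The overloads below mirror the corresponding `show` signatures (number of rows, truncation, vertical layout).
+     */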
+    def dataAsString(): String = {
+      val showFnc: () => Unit = df.show
+      gatherData(showFnc)
+    }
+
+    def dataAsString(truncate: Boolean): String = {
+      val showFnc: () => Unit = () => df.show(truncate)
+      gatherData(showFnc)
+    }
+
+    def dataAsString(numRows: Int, truncate: Boolean): String = {
+      val showFnc: () => Unit = () => df.show(numRows, truncate)
+      gatherData(showFnc)
+    }
+
+    def dataAsString(numRows: Int, truncate: Int): String = {
+      val showFnc: () => Unit = () => df.show(numRows, truncate)
+      gatherData(showFnc)
+    }
+
+    def dataAsString(numRows: Int, truncate: Int, vertical: Boolean): String = {
+      val showFnc: () => Unit = () => df.show(numRows, truncate, vertical)
+      gatherData(showFnc)
+    }
+
+    /**
+     * Adds a column to a dataframe if it does not exist; if it already exists, the provided `ifExists` function is applied instead.
+     *
+     * @param ifExists A function to apply when the column already exists
+     * @param colName  A column to add if it does not exist already
+     * @param colExpr  An expression for the column to add
+     * @return a new dataframe with the new column
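+     *
+     * Example of usage (the column name and the default value are illustrative):
+     * {{{
+     *   df.withColumnIfDoesNotExist((df, _) => df)("errCol", lit(0))
+     * }}}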
+     */
+    def withColumnIfDoesNotExist(ifExists: (DataFrame, String) => DataFrame)(colName: String, colExpr: Column): DataFrame = {
+      if (df.schema.exists(field => field.name.equalsIgnoreCase(colName))) {
+        ifExists(df, colName)
+      } else {
+        df.withColumn(colName, colExpr)
+      }
+    }
+  }
+
+}
diff --git a/src/main/scala/za/co/absa/spark/commons/SchemaUtils.scala b/src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala
similarity index 99%
rename from src/main/scala/za/co/absa/spark/commons/SchemaUtils.scala
rename to src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala
index c3300a41..9de12d61 100644
--- a/src/main/scala/za/co/absa/spark/commons/SchemaUtils.scala
+++ b/src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala
@@ -14,13 +14,13 @@
  * limitations under the License.
  */
 
-package za.co.absa.spark.commons
+package za.co.absa.spark.commons.schema
 
 import org.apache.spark.sql.functions.{col, struct}
 import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType}
 import org.apache.spark.sql.{Column, DataFrame}
 import za.co.absa.spark.commons.adapters.HofsAdapter
 
 object SchemaUtils extends HofsAdapter {
   /**
    * Compares 2 array fields of a dataframe schema.
diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala b/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala
new file mode 100644
index 00000000..ac5a5780
--- /dev/null
+++ b/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala
@@ -0,0 +1,237 @@
+/*
+ * Copyright 2021 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.spark.commons.implicits
+
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.lit
+import org.scalatest.funsuite.AnyFunSuite
+import za.co.absa.spark.commons.test.SparkTestBase
+
+class DataFrameImplicitsSuite extends AnyFunSuite with SparkTestBase {
+  import spark.implicits._
+
+  private val columnName = "data"
+  private val inputDataSeq = Seq(
+    "0123456789012345678901234",
+    "a",
+    "b",
+    "c",
+    "d",
+    "e",
+    "f",
+    "g",
+    "h",
+    "i",
+    "j",
+    "k",
+    "l",
+    "m",
+    "n",
+    "o",
+    "p",
+    "q",
+    "r",
+    "s",
+    "t",
+    "u",
+    "v",
+    "w",
+    "x",
+    "y",
+    "z"
+  )
+  private val inputData = inputDataSeq.toDF(columnName)
+
+  import za.co.absa.spark.commons.implicits.DataFrameImplicits.DataFrameEnhancements
+
+  private def getDummyDataFrame: DataFrame = {
+    import spark.implicits._
+
+    Seq(1, 1, 1, 2, 1).toDF("value")
+  }
+
+  private def cellText(text: String, width: Int, leftAlign: Boolean): String = {
+    val pad = " " * (width - text.length)
+    if (leftAlign) {
+      text + pad
+    } else {
+      pad + text
+    }
+  }
+
+  private def line(width: Int): String = {
+    "+" + "-" * width + "+"
+  }
+
+  private def header(width: Int, leftAlign: Boolean): String = {
+    val lineStr = line(width)
+    val title = cellText(columnName, width, leftAlign)
+    s"$lineStr\n|$title|\n$lineStr"
+  }
+
+  private def cell(text: String, width: Int, leftAlign: Boolean): String = {
+    val inner = if (text.length > width) {
+      text.substring(0, width - 3) + "..."
+    } else {
+      cellText(text, width, leftAlign)
+    }
+    s"|$inner|"
+  }
+
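+  // Builds the expected `show`-style rendering of the test data: header, one cell per row, a closing line,
+  // and the "only showing top N rows" note when the limit cuts the data off.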
+  private def inputDataToString(width: Int, leftAlign: Boolean, limit: Option[Int] = Option(20)): String = {
+    val (extraLine, seq) = limit match {
+      case Some(n) =>
+        val line = if (inputDataSeq.length > n) {
+          s"only showing top $n rows\n"
+        } else {
+          ""
+        }
+        (line, inputDataSeq.take(n))
+      case None =>
+        ("", inputDataSeq)
+    }
+    seq.foldLeft(header(width, leftAlign)) { (acc, item) =>
+      acc + "\n" + cell(item, width, leftAlign)
+    } + "\n" + line(width) + s"\n$extraLine\n"
+  }
+
+  test("Like show()") {
+    val result = inputData.dataAsString()
+    val leftAlign = false
+    val cellWidth = 20
+    val expected = inputDataToString(cellWidth, leftAlign)
+
+    assert(result == expected)
+  }
+
+  test("Like show(false)") {
+    val result = inputData.dataAsString(false)
+    val leftAlign = true
+    val cellWidth = 25
+    val expected = inputDataToString(cellWidth, leftAlign)
+
+    assert(result == expected)
+  }
+
+  test("Like show(3, true)") {
+    val result = inputData.dataAsString(3, true)
+    val leftAlign = false
+    val cellWidth = 20
+    val expected = inputDataToString(cellWidth, leftAlign, Option(3))
+
+    assert(result == expected)
+  }
+
+  test("Like show(30, false)") {
+    val result = inputData.dataAsString(30, false)
+    val leftAlign = true
+    val cellWidth = 25
+    val expected = inputDataToString(cellWidth, leftAlign, Option(30))
+
+    assert(result == expected)
+  }
+
+  test("Like show(10, 10)") {
+    val result = inputData.dataAsString(10, 10)
+    val leftAlign = false
+    val cellWidth = 10
+    val expected = inputDataToString(cellWidth, leftAlign, Option(10))
+
+    assert(result == expected)
+  }
+
+  test("Like show(50, 50, false)") {
+    val result = inputData.dataAsString(50, 50, false)
+    val leftAlign = false
+    val cellWidth = 25
+    val expected = inputDataToString(cellWidth, leftAlign, Option(50))
+
+    assert(result == expected)
+  }
+
+  test("Test withColumnIfNotExist() when the column does not exist") {
+    val expectedOutput =
+      """+-----+---+
+        ||value|foo|
+        |+-----+---+
+        ||1    |1  |
+        ||1    |1  |
+        ||1    |1  |
+        ||2    |1  |
+        ||1    |1  |
+        |+-----+---+
+        |
+        |""".stripMargin.replace("\r\n", "\n")
+
+    val dfIn = getDummyDataFrame
+    val dfOut = dfIn.withColumnIfDoesNotExist((df: DataFrame, _) => df)("foo", lit(1))
+    val actualOutput = dfOut.dataAsString(truncate = false)
+
+    assert(dfOut.schema.length == 2)
+    assert(dfOut.schema.head.name == "value")
+    assert(dfOut.schema(1).name == "foo")
+    assert(actualOutput == expectedOutput)
+  }
+
+  test("Test withColumnIfNotExist() when the column exists") {
+    val expectedOutput =
+      """+-----+
+        ||value|
+        |+-----+
+        ||1    |
+        ||1    |
+        ||1    |
+        ||2    |
+        ||1    |
+        |+-----+
+        |
+        |""".stripMargin.replace("\r\n", "\n")
+
+    val dfIn = getDummyDataFrame
+    val dfOut = dfIn.withColumnIfDoesNotExist((df: DataFrame, _) => df)("value", lit(1))
+    val actualOutput = dfOut.dataAsString(truncate = false)
+
+    assert(dfIn.schema.length == 1)
+    assert(dfIn.schema.head.name == "value")
+    assert(actualOutput == expectedOutput)
+  }
+
+  test("Test withColumnIfNotExist() when the column exists, but has a different case") {
+    val expectedOutput =
+      """+-----+------+
+        ||value|errCol|
+        |+-----+------+
+        ||1    |[]    |
+        ||1    |[]    |
+        ||1    |[]    |
+        ||2    |[]    |
+        ||1    |[]    |
+        |+-----+------+
+        |
+        |""".stripMargin.replace("\r\n", "\n")
+
+    val dfIn = getDummyDataFrame
+    val function: (DataFrame, String) => DataFrame = (df: DataFrame, _) => df.withColumn("errCol", lit(Array.emptyIntArray))
+    val dfOut = dfIn.withColumnIfDoesNotExist(function)("vAlUe", lit(1))
+    val actualOutput = dfOut.dataAsString(truncate = false)
+
+    assert(dfIn.schema.length == 1)
+    assert(dfIn.schema.head.name == "value")
+    assert(actualOutput == expectedOutput)
+  }
+}
diff --git a/src/test/scala/za/co/absa/spark/commons/SchemaUtilsSpec.scala b/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilsSpec.scala
similarity index 99%
rename from src/test/scala/za/co/absa/spark/commons/SchemaUtilsSpec.scala
rename to src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilsSpec.scala
index 9cf1cf7c..771e19ad 100644
--- a/src/test/scala/za/co/absa/spark/commons/SchemaUtilsSpec.scala
+++ b/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilsSpec.scala
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-package za.co.absa.spark.commons
+package za.co.absa.spark.commons.schema
 
 import org.apache.spark.sql.AnalysisException
 import org.scalatest.BeforeAndAfterAll