28 changes: 25 additions & 3 deletions README.md
@@ -50,8 +50,7 @@ select to order and positionally filter columns of a DataFrame

```scala
SchemaUtils.alignSchema(dataFrameToBeAligned, modelSchema)
```

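A minimal sketch of the call above (the model schema contents and variable names are assumptions for illustration; `alignSchema` takes the DataFrame to align and the target schema):

```scala
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

// Assumed target layout: keep "id" and "name", in this order.
val modelSchema = StructType(Seq(
  StructField("id", LongType),
  StructField("name", StringType)
))
val aligned = SchemaUtils.alignSchema(dataFrameToBeAligned, modelSchema)
```
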
# Spark Version Guard

A class that checks whether the Spark job's version is compatible with the Spark versions supported by the library
@@ -69,4 +68,27 @@
Checking for 2.X versions
```scala
SparkVersionGuard.fromSpark2XCompatibilitySettings.ensureSparkVersionCompatibility(SPARK_VERSION)
```
Checking for 3.X versions
```scala
SparkVersionGuard.fromSpark3XCompatibilitySettings.ensureSparkVersionCompatibility(SPARK_VERSION)
```
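
In a job entry point the guard is typically invoked before any work is done, so an unsupported runtime fails fast. A minimal sketch (the app name is an assumption; `SPARK_VERSION` comes from the `org.apache.spark` package object):

```scala
import org.apache.spark.SPARK_VERSION
import org.apache.spark.sql.SparkSession

// Abort with a clear error before building the session if the runtime
// Spark version falls outside the supported 3.X range.
SparkVersionGuard.fromSpark3XCompatibilitySettings.ensureSparkVersionCompatibility(SPARK_VERSION)

val spark = SparkSession.builder().appName("my-job").getOrCreate()
```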

### DataFrameImplicits
_DataFrameImplicits_ provides methods for transformations on DataFrames.

1. Getting the data of the DataFrame as a string, formatted the same way the `show` function presents it (a usage sketch follows the overloads below).

```scala
df.dataAsString()

df.dataAsString(truncate)

df.dataAsString(numRows, truncate)

df.dataAsString(numRows, truncateNumber)

df.dataAsString(numRows, truncate, vertical)
```
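
A minimal usage sketch (the `df` value and the argument choices are assumptions for illustration):

```scala
// Render the first 3 rows without truncating cell contents: the same text
// that df.show(3, false) would print to the console, returned as a String.
val rendered: String = df.dataAsString(3, false)
println(rendered)
```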

2. Adding a column to a DataFrame if it does not already exist. If the column does exist, the provided function is applied instead (see the sketch below).

```scala
df.withColumnIfDoesNotExist((df: DataFrame, _) => df)(colName, colExpression)
```
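
For example, a minimal sketch (the column name and constant are assumptions; the fallback here returns the DataFrame unchanged, mirroring the call above):

```scala
import org.apache.spark.sql.functions.lit

// Add "errCol" with a constant value of 1 when it is absent; if a column of
// that name already exists (matched case-insensitively), keep the DataFrame as-is.
val result = df.withColumnIfDoesNotExist((existing, _) => existing)("errCol", lit(1))
```
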
@@ -0,0 +1,78 @@
/*
 * Copyright 2021 ABSA Group Limited
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package za.co.absa.spark.commons.implicits

import java.io.ByteArrayOutputStream

import org.apache.spark.sql.{Column, DataFrame}

object DataFrameImplicits {

  implicit class DataFrameEnhancements(val df: DataFrame) {

    // Captures everything the given function prints to the console (here: a
    // DataFrame.show variant) and returns it as a string, normalizing Windows
    // line endings.
    private def gatherData(showFnc: () => Unit): String = {
      val outCapture = new ByteArrayOutputStream
      Console.withOut(outCapture) {
        showFnc()
      }
      val dfData = new String(outCapture.toByteArray).replace("\r\n", "\n")
      dfData
    }

    def dataAsString(): String = {
      val showFnc: () => Unit = df.show
      gatherData(showFnc)
    }

    def dataAsString(truncate: Boolean): String = {
      val showFnc: () => Unit = () => df.show(truncate)
      gatherData(showFnc)
    }

    def dataAsString(numRows: Int, truncate: Boolean): String = {
      val showFnc: () => Unit = () => df.show(numRows, truncate)
      gatherData(showFnc)
    }

    def dataAsString(numRows: Int, truncate: Int): String = {
      val showFnc: () => Unit = () => df.show(numRows, truncate)
      gatherData(showFnc)
    }

    def dataAsString(numRows: Int, truncate: Int, vertical: Boolean): String = {
      val showFnc: () => Unit = () => df.show(numRows, truncate, vertical)
      gatherData(showFnc)
    }

    /**
      * Adds a column to a dataframe if it does not exist
      *
      * @param ifExists A function to apply when the column already exists
      * @param colName  A column to add if it does not exist already
      * @param colExpr  An expression for the column to add
      * @return a new dataframe with the new column
      */
    def withColumnIfDoesNotExist(ifExists: (DataFrame, String) => DataFrame)(colName: String, colExpr: Column): DataFrame = {
      if (df.schema.exists(field => field.name.equalsIgnoreCase(colName))) {
        ifExists(df, colName)
      } else {
        df.withColumn(colName, colExpr)
      }
    }
  }

}
@@ -14,13 +14,15 @@
 * limitations under the License.
 */

package za.co.absa.spark.commons
package za.co.absa.spark.commons.schema

import org.apache.spark.sql.functions.{col, struct}
import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType}
import org.apache.spark.sql.{Column, DataFrame}
import za.co.absa.spark.commons.adapters.HofsAdapter

import scala.util.Random

object SchemaUtils extends HofsAdapter {
  /**
   * Compares 2 array fields of a dataframe schema.
@@ -0,0 +1,237 @@
/*
 * Copyright 2021 ABSA Group Limited
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package za.co.absa.spark.commons.implicits

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.lit
import org.scalatest.funsuite.AnyFunSuite
import za.co.absa.spark.commons.test.SparkTestBase

class DataFrameImplicitsSuite extends AnyFunSuite with SparkTestBase {
  import spark.implicits._

  private val columnName = "data"
  private val inputDataSeq = Seq(
    "0123456789012345678901234",
    "a",
    "b",
    "c",
    "d",
    "e",
    "f",
    "g",
    "h",
    "i",
    "j",
    "k",
    "l",
    "m",
    "n",
    "o",
    "p",
    "q",
    "r",
    "s",
    "t",
    "u",
    "v",
    "w",
    "x",
    "y",
    "z"
  )
  private val inputData = inputDataSeq.toDF(columnName)

  import za.co.absa.spark.commons.implicits.DataFrameImplicits.DataFrameEnhancements

  private def getDummyDataFrame: DataFrame = {
    import spark.implicits._

    Seq(1, 1, 1, 2, 1).toDF("value")
  }

  private def cellText(text: String, width: Int, leftAlign: Boolean): String = {
    val pad = " " * (width - text.length)
    if (leftAlign) {
      text + pad
    } else {
      pad + text
    }
  }

  private def line(width: Int): String = {
    "+" + "-" * width + "+"
  }

  private def header(width: Int, leftAlign: Boolean): String = {
    val lineStr = line(width)
    val title = cellText(columnName, width, leftAlign)
    s"$lineStr\n|$title|\n$lineStr"
  }

  private def cell(text: String, width: Int, leftAlign: Boolean): String = {
    val inner = if (text.length > width) {
      text.substring(0, width - 3) + "..."
    } else {
      cellText(text, width, leftAlign)
    }
    s"|$inner|"
  }

  private def inputDataToString(width: Int, leftAlign: Boolean, limit: Option[Int] = Option(20)): String = {
    val (extraLine, seq) = limit match {
      case Some(n) =>
        val line = if (inputDataSeq.length > n) {
          s"only showing top $n rows\n"
        } else {
          ""
        }
        (line, inputDataSeq.take(n))
      case None =>
        ("", inputDataSeq)
    }
    seq.foldLeft(header(width, leftAlign)) { (acc, item) =>
      acc + "\n" + cell(item, width, leftAlign)
    } + "\n" + line(width) + s"\n$extraLine\n"
  }

test("Like show()") {
val result = inputData.dataAsString()
val leftAlign = false
val cellWidth = 20
val expected = inputDataToString(cellWidth, leftAlign)

assert(result == expected)
}

test("Like show(false)") {
val result = inputData.dataAsString(false)
val leftAlign = true
val cellWidth = 25
val expected = inputDataToString(cellWidth, leftAlign)

assert(result == expected)
}

test("Like show(3, true)") {
val result = inputData.dataAsString(3, true)
val leftAlign = false
val cellWidth = 20
val expected = inputDataToString(cellWidth, leftAlign, Option(3))

assert(result == expected)
}

test("Like show(30, false)") {
val result = inputData.dataAsString(30, false)
val leftAlign = true
val cellWidth = 25
val expected = inputDataToString(cellWidth, leftAlign, Option(30))

assert(result == expected)
}


test("Like show(10, 10)") {
val result = inputData.dataAsString(10, 10)
val leftAlign = false
val cellWidth = 10
val expected = inputDataToString(cellWidth, leftAlign, Option(10))

assert(result == expected)
}

test("Like show(50, 50, false)") {
val result = inputData.dataAsString(50, 50, false)
val leftAlign = false
val cellWidth = 25
val expected = inputDataToString(cellWidth, leftAlign, Option(50))

assert(result == expected)
}

test("Test withColumnIfNotExist() when the column does not exist") {
val expectedOutput =
"""+-----+---+
||value|foo|
|+-----+---+
||1 |1 |
||1 |1 |
||1 |1 |
||2 |1 |
||1 |1 |
|+-----+---+
|
|""".stripMargin.replace("\r\n", "\n")

val dfIn = getDummyDataFrame
val dfOut = dfIn.withColumnIfDoesNotExist((df: DataFrame, _) => df)("foo", lit(1))
val actualOutput = dfOut.dataAsString(truncate = false)

assert(dfOut.schema.length == 2)
assert(dfOut.schema.head.name == "value")
assert(dfOut.schema(1).name == "foo")
assert(actualOutput == expectedOutput)
}

test("Test withColumnIfNotExist() when the column exists") {
val expectedOutput =
"""+-----+
||value|
|+-----+
||1 |
||1 |
||1 |
||2 |
||1 |
|+-----+
|
|""".stripMargin.replace("\r\n", "\n")

val dfIn = getDummyDataFrame
val dfOut = dfIn.withColumnIfDoesNotExist((df: DataFrame, _) => df)("value", lit(1))
val actualOutput = dfOut.dataAsString(truncate = false)

assert(dfIn.schema.length == 1)
assert(dfIn.schema.head.name == "value")
assert(actualOutput == expectedOutput)
}

test("Test withColumnIfNotExist() when the column exists, but has a different case") {
val expectedOutput =
"""+-----+------+
||value|errCol|
|+-----+------+
||1 |[] |
||1 |[] |
||1 |[] |
||2 |[] |
||1 |[] |
|+-----+------+
|
|""".stripMargin.replace("\r\n", "\n")

val dfIn = getDummyDataFrame
val function: (DataFrame, String) => DataFrame = (df: DataFrame, _) => df.withColumn("errCol", lit(Array.emptyIntArray))
val dfOut = dfIn.withColumnIfDoesNotExist(function)("vAlUe", lit(1))
val actualOutput = dfOut.dataAsString(truncate = false)

assert(dfIn.schema.length == 1)
assert(dfIn.schema.head.name == "value")
assert(actualOutput == expectedOutput)
}
}
@@ -14,7 +14,7 @@
 * limitations under the License.
 */

package za.co.absa.spark.commons
package za.co.absa.spark.commons.schema

import org.apache.spark.sql.AnalysisException
import org.scalatest.BeforeAndAfterAll