Merged
30 changes: 29 additions & 1 deletion README.md
@@ -50,4 +50,32 @@ select to order and positionally filter columns of a DataFrame

```scala
SchemaUtils.alignSchema(dataFrameToBeAligned, modelSchema)
```
5. Getting a unique column name, optionally checked against a provided schema so it does not clash with existing columns

```scala
SchemaUtils.getUniqueName(prefix, Some(modelSchema))
```
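
For illustration, a minimal sketch of a call (the schema and prefix below are made-up examples):

```scala
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import za.co.absa.spark.commons.schema.SchemaUtils

val modelSchema = StructType(Seq(StructField("id", StringType)))
// Returns e.g. "tmp_5734716423461871404" - a random suffix, checked against modelSchema
val uniqueCol = SchemaUtils.getUniqueName("tmp", Some(modelSchema))
// Without a schema the name is still random, just not checked against anything
val quickCol = SchemaUtils.getUniqueName("tmp", None)
```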

### DataFrameImplicits
_DataFrameImplicits_ provides extension methods for transformations on DataFrames.

1. Getting the output of a DataFrame's `show` as a `String` instead of printing it to the console

```scala
df.dataAsString()

df.dataAsString(truncate)

df.dataAsString(numRows, truncate)

df.dataAsString(numRows, truncateNumber)

df.dataAsString(numRows, truncate, vertical)
```
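
As a rough usage sketch (assuming an active `SparkSession` named `spark`):

```scala
import za.co.absa.spark.commons.implicits.DataFrameImplicits.DataFrameEnhancements
import spark.implicits._

val df = Seq(1, 2, 3).toDF("value")
// Captures the text that df.show(2, false) would print to the console
val captured = df.dataAsString(2, false)
println(captured)
```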

2. Adding a column to a DataFrame if it does not already exist. If it does exist, the column is overwritten with the provided expression and an error is recorded in the error column for every row whose original value differs

```scala
df.withColumnIfDoesNotExist(colName, colExpression)
```
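
A hedged sketch of both branches (the column names and values here are illustrative):

```scala
import org.apache.spark.sql.functions.lit
import za.co.absa.spark.commons.implicits.DataFrameImplicits.DataFrameEnhancements

// "source" is absent, so the column is simply added
val withSource = df.withColumnIfDoesNotExist("source", lit("batch"))
// "source" now exists: its content is overwritten, and rows whose original
// value differs from lit("stream") get an entry in the errCol error column
val overwritten = withSource.withColumnIfDoesNotExist("source", lit("stream"))
```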
1 change: 1 addition & 0 deletions build.sbt
@@ -29,6 +29,7 @@ libraryDependencies ++= List(
"org.apache.spark" %% "spark-sql" % sparkVersion % "provided",
"za.co.absa.commons" %% "commons" % "1.0.0",
"za.co.absa" %% "spark-hofs" % "0.4.0",
"za.co.absa" %% "spark-hats" % "0.2.2",
"org.scala-lang" % "scala-compiler" % scalaVersion.value,
"org.scalatest" %% "scalatest" % "3.1.0" % Test,
"org.scalatest" %% "scalatest-flatspec" % "3.2.0" % Test,
43 changes: 43 additions & 0 deletions src/main/scala/za/co/absa/spark/commons/error/ErrorMessage.scala
@@ -0,0 +1,43 @@
/*
* Copyright 2021 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.spark.commons.error

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType

/**
* Case class to represent an error message
*
* @param errType - Type or source of the error
* @param errCode - Internal error code
* @param errMsg - Textual description of the error
* @param errCol - The name of the column where the error occurred
* @param rawValues - Sequence of raw values (which are the potential culprits of the error)
* @param mappings - Sequence of Mappings, i.e. mapping table column -> equivalent mapped dataset column
*/
case class ErrorMessage(errType: String, errCode: String, errMsg: String, errCol: String, rawValues: Seq[String], mappings: Seq[Mapping] = Seq())
case class Mapping(mappingTableColumn: String, mappedDatasetColumn: String)

object ErrorMessage {
val errorColumnName = "errCol"

def errorColSchema(implicit spark: SparkSession): StructType = {
import spark.implicits._
spark.emptyDataset[ErrorMessage].schema
}
}
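
A small sketch of constructing an error and inspecting the helper schema (the field values are illustrative, and `spark` is an assumed active session):

```scala
import org.apache.spark.sql.SparkSession
import za.co.absa.spark.commons.error.ErrorMessage

implicit val spark: SparkSession = SparkSession.builder()
  .master("local[*]")
  .appName("error-example")
  .getOrCreate()

// Illustrative values only - error types and codes are caller-defined
val err = ErrorMessage(
  errType = "castError",
  errCode = "E001",
  errMsg = "Value could not be cast to the target type",
  errCol = "amount",
  rawValues = Seq("abc"))

// StructType matching ErrorMessage, handy for pre-creating the errCol column
println(ErrorMessage.errorColSchema.treeString)
```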

143 changes: 143 additions & 0 deletions src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala
@@ -0,0 +1,143 @@
/*
* Copyright 2021 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.spark.commons.implicits

import java.io.ByteArrayOutputStream

import org.apache.log4j.{LogManager, Logger}
import org.apache.spark.sql.functions.{array, callUDF, col, lit, when}
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import za.co.absa.spark.commons.error.ErrorMessage
import za.co.absa.spark.commons.schema.SchemaUtils
import za.co.absa.spark.hats.transformations.NestedArrayTransformations

import scala.collection.mutable

object DataFrameImplicits {

private val log: Logger = LogManager.getLogger(this.getClass)

private val overWriteErrorFunction = "overWriteErr"
private val overWriteErrorType = "overWriteError"
private val overWriteErrorCode = "E000OW"

implicit class DataFrameEnhancements(val df: DataFrame) {

private def gatherData(showFnc: () => Unit): String = {
val outCapture = new ByteArrayOutputStream
Console.withOut(outCapture) {
showFnc()
}
val dfData = new String(outCapture.toByteArray).replace("\r\n", "\n")
dfData
}

def dataAsString(): String = {
val showFnc: () => Unit = df.show
gatherData(showFnc)
}

def dataAsString(truncate: Boolean): String = {
val showFnc: () => Unit = () => df.show(truncate)
gatherData(showFnc)
}

def dataAsString(numRows: Int, truncate: Boolean): String = {
val showFnc: () => Unit = () => df.show(numRows, truncate)
gatherData(showFnc)
}

def dataAsString(numRows: Int, truncate: Int): String = {
val showFnc: () => Unit = () => df.show(numRows, truncate)
gatherData(showFnc)
}

def dataAsString(numRows: Int, truncate: Int, vertical: Boolean): String = {
val showFnc: () => Unit = () => df.show(numRows, truncate, vertical)
gatherData(showFnc)
}

/**
* Adds a column to a dataframe if it does not exist. If the column already exists,
* its content is overwritten and an error is recorded in the error column.
*
* @param colName A column to add if it does not exist already
* @param colExpr An expression for the column to add
* @return a new dataframe with the new column
*/
def withColumnIfDoesNotExist(colName: String, colExpr: Column): DataFrame = {
if (df.schema.exists(field => field.name.equalsIgnoreCase(colName))) {
log.warn(s"Column '$colName' already exists. The content of the column will be overwritten.")
overwriteWithErrorColumn(df, colName, colExpr)
Contributor: This again is pretty Enceladus-specific. I would make this branch of code an input parameter:

ifExists: (DataFrame, String) => Unit = (_, _) => {}

This would also eliminate the need for the specific errorColumn code. (A hedged sketch of this idea appears after this file.)

Contributor (Author): So remove the whole overwriteWithErrorColumn function?
} else {
df.withColumn(colName, colExpr)
}
}

/**
* Overwrites a column with a value provided by an expression.
* If the value in the column does not match the one provided by the expression, an error will be
* added to the error column.
*
* @param df A dataframe
* @param colName A column to be overwritten
* @param colExpr An expression for the value to write
* @return a new dataframe with the value of the column being overwritten
*/
private def overwriteWithErrorColumn(df: DataFrame, colName: String, colExpr: Column): DataFrame = {
implicit val spark: SparkSession = df.sparkSession
spark.udf.register(overWriteErrorFunction, { (errCol: String, rawValue: String) =>
ErrorMessage(
errType = overWriteErrorType,
errCode = overWriteErrorCode,
errMsg = "Special column value has changed",
errCol = errCol,
rawValues = Seq(rawValue))
})
spark.udf.register("arrayDistinctErrors", // this UDF is registered for _spark-hats_ library sake
(arr: mutable.WrappedArray[ErrorMessage]) =>
if (arr != null) {
arr.distinct.filter((a: AnyRef) => a != null)
} else {
Seq[ErrorMessage]()
}
)

val tmpColumn = SchemaUtils.getUniqueName("tmpColumn", Some(df.schema))
val tmpErrColumn = SchemaUtils.getUniqueName("tmpErrColumn", Some(df.schema))
val litErrUdfCall = callUDF(overWriteErrorFunction, lit(colName), col(tmpColumn))

// Rename the original column to a temporary name. We need it for comparison.
val dfWithColRenamed = df.withColumnRenamed(colName, tmpColumn)

// Add new column with the intended value
val dfWithIntendedColumn = dfWithColRenamed.withColumn(colName, colExpr)

// Add a temporary error column containing errors if the original value does not match the intended one
val dfWithErrorColumn = dfWithIntendedColumn
.withColumn(tmpErrColumn, array(when(col(tmpColumn) =!= colExpr, litErrUdfCall).otherwise(null))) // scalastyle:ignore null

// Gather all errors in errCol
val dfWithAggregatedErrColumn = NestedArrayTransformations
.gatherErrors(dfWithErrorColumn, tmpErrColumn, ErrorMessage.errorColumnName)

// Drop the temporary column
dfWithAggregatedErrColumn.drop(tmpColumn)
}

}

}
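
Following the review thread above, a hedged sketch of what the suggested parameterization might look like (hypothetical code, not the merged implementation; the reviewer's `(DataFrame, String) => Unit` is widened here to return a `DataFrame` so the branch can produce a result):

```scala
import org.apache.spark.sql.{Column, DataFrame}

// Hypothetical variant based on the review suggestion - not the merged API.
// The caller injects the behaviour for an already-existing column, so the
// Enceladus-specific error-column handling no longer needs to live here.
def withColumnIfDoesNotExist(df: DataFrame, colName: String, colExpr: Column)
                            (ifExists: (DataFrame, String) => DataFrame = (d, _) => d): DataFrame = {
  if (df.schema.exists(_.name.equalsIgnoreCase(colName))) {
    ifExists(df, colName)
  } else {
    df.withColumn(colName, colExpr)
  }
}
```

With this shape, Enceladus could pass its overwriteWithErrorColumn logic as `ifExists`, while other users keep the no-op default.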
@@ -14,13 +14,15 @@
* limitations under the License.
*/

package za.co.absa.spark.commons
package za.co.absa.spark.commons.schema

import org.apache.spark.sql.functions.{col, struct}
import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType}
import org.apache.spark.sql.{Column, DataFrame}
import za.co.absa.spark.commons.adapters.HofsAdapter

import scala.util.Random

object SchemaUtils extends HofsAdapter {
/**
* Compares 2 array fields of a dataframe schema.
@@ -230,4 +232,26 @@ object SchemaUtils extends HofsAdapter {
private def getMapOfFields(schema: StructType): Map[String, StructField] = {
schema.map(field => field.name.toLowerCase() -> field).toMap
}

/**
* Generate a unique column name
*
* @param prefix A prefix to use for the column name
* @param schema An optional schema to check that the generated name does not clash with an existing column (a collision has very low probability)
* @return A name that can be used as a unique column name
*/
def getUniqueName(prefix: String, schema: Option[StructType]): String = {
Contributor: I would make this an implicit on StructType. When not checked against a schema it's so trivial, I wouldn't even bother making it a common function. (A hedged sketch of this idea appears after this file.)

Contributor (Author): It can be removed since it will be replaced by this.

Contributor: Not sure how the linked code solves the need for a unique column name? 🤔

schema match {
case None =>
s"${prefix}_${Random.nextLong().abs}"
case Some(sch) =>
var exists = true
var columnName = ""
while (exists) {
columnName = s"${prefix}_${Random.nextLong().abs}"
exists = sch.fields.exists(_.name.compareToIgnoreCase(columnName) == 0)
}
columnName
}
}
}
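
A hedged sketch of the implicit-on-StructType idea from the review thread (hypothetical names, not part of this PR):

```scala
import scala.util.Random
import org.apache.spark.sql.types.StructType

// Hypothetical enhancement sketched from the review comment - not merged code
object StructTypeImplicits {
  implicit class StructTypeEnhancements(val schema: StructType) {
    def getUniqueName(prefix: String): String = {
      var name = s"${prefix}_${Random.nextLong().abs}"
      // Regenerate on the (very unlikely) collision with an existing field
      while (schema.fields.exists(_.name.equalsIgnoreCase(name))) {
        name = s"${prefix}_${Random.nextLong().abs}"
      }
      name
    }
  }
}
```

Usage would then read something like `df.schema.getUniqueName("tmp")`, keeping the trivial no-schema case at call sites.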