# #17 DataFrameImplicits (PR #18)
Changes from 2 commits
### ErrorMessage.scala (new file)

```scala
/*
 * Copyright 2021 ABSA Group Limited
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package za.co.absa.spark.commons.error

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType

/**
 * Case class to represent an error message.
 *
 * @param errType   Type or source of the error
 * @param errCode   Internal error code
 * @param errMsg    Textual description of the error
 * @param errCol    The name of the column where the error occurred
 * @param rawValues Sequence of raw values (the potential culprits of the error)
 * @param mappings  Sequence of mappings, i.e. mapping table column -> equivalent mapped dataset column
 */
case class ErrorMessage(errType: String, errCode: String, errMsg: String, errCol: String, rawValues: Seq[String], mappings: Seq[Mapping] = Seq())

case class Mapping(mappingTableColumn: String, mappedDatasetColumn: String)

object ErrorMessage {
  val errorColumnName = "errCol"

  def errorColSchema(implicit spark: SparkSession): StructType = {
    import spark.implicits._
    spark.emptyDataset[ErrorMessage].schema
  }
}
```
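For orientation, here is a minimal sketch of how the error schema can be obtained and an error record constructed. The session settings and the `price` column are illustrative, not part of this PR; the `errType` and `errCode` values are the ones the second file of this PR uses.

```scala
import org.apache.spark.sql.SparkSession
import za.co.absa.spark.commons.error.ErrorMessage

val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()

// Schema of a single error record, derived from the case class via an empty Dataset
ErrorMessage.errorColSchema(spark).printTreeString()

// An error record as it would appear inside the errCol array
val err = ErrorMessage(
  errType = "overWriteError",                        // type/source of the error
  errCode = "E000OW",                                // internal error code
  errMsg = "Special column value has changed",
  errCol = "price",                                  // column where the error occurred
  rawValues = Seq("abc"))                            // offending raw value(s)
```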
### DataFrameImplicits.scala (new file)

The new `DataFrameEnhancements` implicit class wraps a `DataFrame` with helpers for rendering its contents as a `String` and for error-aware column overwriting:

```scala
/* Apache License, Version 2.0 header as above */

package za.co.absa.spark.commons.implicits

import java.io.ByteArrayOutputStream

import org.apache.log4j.{LogManager, Logger}
import org.apache.spark.sql.functions.{array, callUDF, col, lit, when}
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import za.co.absa.spark.commons.error.ErrorMessage
import za.co.absa.spark.commons.schema.SchemaUtils
import za.co.absa.spark.hats.transformations.NestedArrayTransformations

import scala.collection.mutable

object DataFrameImplicits {

  private val log: Logger = LogManager.getLogger(this.getClass)

  private val overWriteErrorFunction = "overWriteErr"
  private val overWriteErrorType = "overWriteError"
  private val overWriteErrorCode = "E000OW"

  implicit class DataFrameEnhancements(val df: DataFrame) {

    // Runs the given show function with stdout redirected to an in-memory
    // buffer and returns the captured text (line endings normalized).
    private def gatherData(showFnc: () => Unit): String = {
      val outCapture = new ByteArrayOutputStream
      Console.withOut(outCapture) {
        showFnc()
      }
      new String(outCapture.toByteArray).replace("\r\n", "\n")
    }
```
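The capture trick deserves a note: `df.show` prints through Scala's `Console`, which is why `Console.withOut` can intercept it. The same pattern in isolation, independent of Spark (names here are illustrative):

```scala
import java.io.ByteArrayOutputStream

// Runs a side-effecting printer and returns what it wrote to stdout.
def captureStdout(body: () => Unit): String = {
  val buffer = new ByteArrayOutputStream()
  Console.withOut(buffer) {
    body()
  }
  new String(buffer.toByteArray).replace("\r\n", "\n")
}

assert(captureStdout(() => println("hello")) == "hello\n")
```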
The `dataAsString` overloads mirror the `df.show` overloads, all funneling through `gatherData`:

```scala
    def dataAsString(): String = {
      gatherData(() => df.show())
    }

    def dataAsString(truncate: Boolean): String = {
      gatherData(() => df.show(truncate))
    }

    def dataAsString(numRows: Int, truncate: Boolean): String = {
      gatherData(() => df.show(numRows, truncate))
    }

    def dataAsString(numRows: Int, truncate: Int): String = {
      gatherData(() => df.show(numRows, truncate))
    }

    def dataAsString(numRows: Int, truncate: Int, vertical: Boolean): String = {
      gatherData(() => df.show(numRows, truncate, vertical))
    }
```
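Usage is then a one-liner wherever the tabular form of a DataFrame needs to land somewhere other than stdout. A sketch, assuming a `SparkSession` named `spark` and a logger in scope:

```scala
import za.co.absa.spark.commons.implicits.DataFrameImplicits.DataFrameEnhancements

val df = spark.range(3).toDF("id")

// Same text df.show(3, truncate = false) would print, but returned as a value
val rendered: String = df.dataAsString(3, truncate = false)
log.info(rendered)
```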
The remainder of the implicit class implements the error-aware overwrite: the original column is renamed to a temporary name, the intended value is written under the original name, rows where the two disagree get an `ErrorMessage` gathered into `errCol` via spark-hats' `NestedArrayTransformations.gatherErrors`, and the temporary column is dropped.

```scala
    /**
     * Adds a column to a dataframe if it does not exist.
     *
     * @param colName A column to add if it does not exist already
     * @param colExpr An expression for the column to add
     * @return a new dataframe with the new column
     */
    def withColumnIfDoesNotExist(colName: String, colExpr: Column): DataFrame = {
      if (df.schema.exists(field => field.name.equalsIgnoreCase(colName))) {
        log.warn(s"Column '$colName' already exists. The content of the column will be overwritten.")
        overwriteWithErrorColumn(df, colName, colExpr)
      } else {
        df.withColumn(colName, colExpr)
      }
    }

    /**
     * Overwrites a column with a value provided by an expression.
     * If the value in the column does not match the one provided by the expression, an error is
     * added to the error column.
     *
     * @param df      A dataframe
     * @param colName A column to be overwritten
     * @param colExpr An expression for the value to write
     * @return a new dataframe with the value of the column overwritten
     */
    private def overwriteWithErrorColumn(df: DataFrame, colName: String, colExpr: Column): DataFrame = {
      implicit val spark: SparkSession = df.sparkSession
      spark.udf.register(overWriteErrorFunction, { (errCol: String, rawValue: String) =>
        ErrorMessage(
          errType = overWriteErrorType,
          errCode = overWriteErrorCode,
          errMsg = "Special column value has changed",
          errCol = errCol,
          rawValues = Seq(rawValue))
      })
      // This UDF is registered for the sake of the spark-hats library
      spark.udf.register("arrayDistinctErrors",
        (arr: mutable.WrappedArray[ErrorMessage]) =>
          if (arr != null) {
            arr.distinct.filter((a: AnyRef) => a != null)
          } else {
            Seq[ErrorMessage]()
          }
      )

      val tmpColumn = SchemaUtils.getUniqueName("tmpColumn", Some(df.schema))
      val tmpErrColumn = SchemaUtils.getUniqueName("tmpErrColumn", Some(df.schema))
      val litErrUdfCall = callUDF(overWriteErrorFunction, lit(colName), col(tmpColumn))

      // Rename the original column to a temporary name. We need it for comparison.
      val dfWithColRenamed = df.withColumnRenamed(colName, tmpColumn)

      // Add a new column with the intended value
      val dfWithIntendedColumn = dfWithColRenamed.withColumn(colName, colExpr)

      // Add a temporary error column containing errors if the original value does not match the intended one
      val dfWithErrorColumn = dfWithIntendedColumn
        .withColumn(tmpErrColumn, array(when(col(tmpColumn) =!= colExpr, litErrUdfCall).otherwise(null))) // scalastyle:ignore null

      // Gather all errors in errCol
      val dfWithAggregatedErrColumn = NestedArrayTransformations
        .gatherErrors(dfWithErrorColumn, tmpErrColumn, ErrorMessage.errorColumnName)

      // Drop the temporary column
      dfWithAggregatedErrColumn.drop(tmpColumn)
    }
  }

}
```
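A sketch of both branches, assuming a `SparkSession` in scope (column names and values are illustrative; the overwrite path requires spark-hats on the classpath):

```scala
import org.apache.spark.sql.functions.lit
import za.co.absa.spark.commons.implicits.DataFrameImplicits.DataFrameEnhancements

val base = spark.range(3).toDF("id")

// "version" does not exist yet: it is simply added
val added = base.withColumnIfDoesNotExist("version", lit(1))

// "version" already exists with a conflicting value: it is overwritten,
// and each changed row gets an ErrorMessage gathered into errCol
val overwritten = base
  .withColumn("version", lit(2))
  .withColumnIfDoesNotExist("version", lit(1))
```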
### SchemaUtils.scala

Two changes here: the file moves from package `za.co.absa.spark.commons` to the new `za.co.absa.spark.commons.schema` package (with `scala.util.Random` imported for the new helper), and a `getUniqueName` method is added at the end of the object:

```scala
package za.co.absa.spark.commons.schema

import org.apache.spark.sql.functions.{col, struct}
import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType}
import org.apache.spark.sql.{Column, DataFrame}
import za.co.absa.spark.commons.adapters.HofsAdapter

import scala.util.Random

object SchemaUtils extends HofsAdapter {

  // ... (existing members unchanged) ...

  private def getMapOfFields(schema: StructType): Map[String, StructField] = {
    schema.map(field => field.name.toLowerCase() -> field).toMap
  }

  /**
   * Generates a unique column name.
   *
   * @param prefix A prefix to use for the column name
   * @param schema An optional schema to check so that the generated name does not collide
   *               with an existing column (a very low probability)
   * @return A name that can be used as a unique column name
   */
  def getUniqueName(prefix: String, schema: Option[StructType]): String = {
    schema match {
      case None =>
        s"${prefix}_${Random.nextLong().abs}"
      case Some(sch) =>
        var exists = true
        var columnName = ""
        while (exists) {
          columnName = s"${prefix}_${Random.nextLong().abs}"
          exists = sch.fields.exists(_.name.compareToIgnoreCase(columnName) == 0)
        }
        columnName
    }
  }
}
```
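A quick sketch of how the helper behaves (the schema and prefix are illustrative):

```scala
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import za.co.absa.spark.commons.schema.SchemaUtils

// Without a schema: prefix plus a random suffix, no collision check
val name1 = SchemaUtils.getUniqueName("tmpColumn", None)

// With a schema: the helper re-rolls until the name clashes with no existing field
val schema = StructType(Seq(StructField("tmpColumn_42", StringType)))
val name2 = SchemaUtils.getUniqueName("tmpColumn", Some(schema))
assert(!schema.fields.exists(_.name.equalsIgnoreCase(name2)))
```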