From 9d437380d7304d7c20bcccd0ed95ad14341c0a7a Mon Sep 17 00:00:00 2001
From: Adrian Olosutean
Date: Tue, 11 Jan 2022 11:18:15 +0100
Subject: [PATCH 01/19] #19 ColumnImplicits and StructFieldImplicits

---
 README.md                                     | 53 +++++++++++++++
 .../commons/implicits/ColumnImplicits.scala   | 65 +++++++++++++++++++
 .../implicits/StructFieldImplicits.scala      | 48 ++++++++++++++
 3 files changed, 166 insertions(+)
 create mode 100644 src/main/scala/za/co/absa/spark/commons/implicits/ColumnImplicits.scala
 create mode 100644 src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala

diff --git a/README.md b/README.md
index a2267190..ed2e88df 100644
--- a/README.md
+++ b/README.md
@@ -50,4 +50,57 @@ select to order and positionally filter columns of a DataFrame
 ```scala
 SchemaUtils.alignSchema(dataFrameToBeAligned, modelSchema)
+ ```
+
+### ColumnImplicits
+
+_ColumnImplicits_ provides implicit methods for transforming Spark Columns (see the usage sketch after this list)
+
+1. Transforms the column into a boolean column, checking if the values are negative or positive infinity
+
+   ```scala
+   column.isInfinite
+   ```
+2. Returns a column with the requested substring. It shifts the substring indexation to be in accordance with Scala/Java.
+   If the provided starting position is negative, it is counted from the end of the string.
+
+   ```scala
+   column.zeroBasedSubstr(startPos)
+   ```
+3. Returns a column with the requested substring of the given length. It shifts the substring indexation to be in
+   accordance with Scala/Java. If the provided starting position is negative, it is counted from the end of the string.
+   If the desired length is longer than the rest of the string, all the remaining characters are taken.
+
+   ```scala
+   column.zeroBasedSubstr(startPos, length)
+   ```
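+
+   A minimal usage sketch of the substring methods (the literal column below is an assumed example, not part of the library):
+
+   ```scala
+   import org.apache.spark.sql.functions.lit
+   import za.co.absa.spark.commons.implicits.ColumnImplicits.ColumnEnhancements
+
+   val c = lit("abcdef")
+
+   c.zeroBasedSubstr(2)     // "cdef"  - zero-based, like Scala/Java substring
+   c.zeroBasedSubstr(-2)    // "ef"    - negative start counts from the end
+   c.zeroBasedSubstr(1, 3)  // "bcd"   - start at index 1, take 3 characters
+   ```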
+
+### StructFieldImplicits
+
+_StructFieldImplicits_ provides implicit methods for working with StructField objects (see the usage sketch after this list)
+
+1. Gets the metadata String value for a given key
+
+   ```scala
+   structField.getMetadataString(key)
+   ```
+
+2. Gets the metadata Char value for a given key: if the value is a single-character String, it returns the char,
+   otherwise None
+
+   ```scala
+   structField.getMetadataChar(key)
+   ```
+
+3. Gets the metadata value of a given key as a Boolean, provided it can be converted to a Boolean
+
+   ```scala
+   structField.getMetadataStringAsBoolean(key)
+   ```
+
+4. Checks if the StructField's metadata contains the provided key, returns a boolean
+
+   ```scala
+   structField.hasMetadataKey(key)
+   ```
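+
+   A minimal usage sketch of the metadata accessors (the field definition below is an assumed example, not part of the library):
+
+   ```scala
+   import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StructField}
+   import za.co.absa.spark.commons.implicits.StructFieldImplicits.StructFieldEnhancements
+
+   val field = StructField("a", IntegerType, nullable = true,
+     new MetadataBuilder().putString("default", "7").putString("flag", "true").build())
+
+   field.getMetadataString("default")        // Some("7")
+   field.getMetadataChar("default")          // Some('7'), because the value is a single character
+   field.getMetadataStringAsBoolean("flag")  // Some(true)
+   field.hasMetadataKey("missing")           // false
+   ```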
\ No newline at end of file
diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/ColumnImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/ColumnImplicits.scala
new file mode 100644
index 00000000..00fce653
--- /dev/null
+++ b/src/main/scala/za/co/absa/spark/commons/implicits/ColumnImplicits.scala
@@ -0,0 +1,65 @@
+/*
+ * Copyright 2021 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.spark.commons.implicits
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.functions._
+
+object ColumnImplicits {
+  implicit class ColumnEnhancements(column: Column) {
+    def isInfinite: Column = {
+      column.isin(Double.PositiveInfinity, Double.NegativeInfinity)
+    }
+
+    /**
+      * Spark string indexing is 1-based, unlike Scala/Java. The function shifts the substring indexation to be in
+      * accordance with Scala/Java.
+      * Another enhancement is that the function allows a negative index, denoting that the index is counted from the back.
+      * This version takes the substring from startPos until the end of the string.
+      *
+      * @param startPos the index (zero-based) where to start the substring from; if negative, it's counted from the end
+      * @return column with the requested substring
+      */
+    def zeroBasedSubstr(startPos: Int): Column = {
+      if (startPos >= 0) {
+        zeroBasedSubstr(startPos, Int.MaxValue - startPos)
+      } else {
+        zeroBasedSubstr(startPos, -startPos)
+      }
+    }
+
+    /**
+      * Spark string indexing is 1-based, unlike Scala/Java. The function shifts the substring indexation to be in
+      * accordance with Scala/Java.
+      * Another enhancement is that the function allows a negative index, denoting that the index is counted from the back.
+      * This version takes the substring from startPos and takes up to the given number of characters (less if the
+      * string is not long enough).
+      *
+      * @param startPos the index (zero-based) where to start the substring from; if negative, it's counted from the end
+      * @param len      length of the desired substring; if longer than the rest of the string, all the remaining characters are taken
+      * @return column with the requested substring
+      */
+    def zeroBasedSubstr(startPos: Int, len: Int): Column = {
+      if (startPos >= 0) {
+        column.substr(startPos + 1, len)
+      } else {
+        // for a negative start, clamp the 1-based start position at the beginning of the string and shorten the
+        // requested length by the overshoot when the negative start reaches before the first character
+        val startPosColumn = greatest(length(column) + startPos + 1, lit(1))
+        val lenColumn = lit(len) +
+          when(length(column) + startPos <= 0, length(column) + startPos).otherwise(0)
+        column.substr(startPosColumn, lenColumn)
+      }
+    }
+  }
+}
diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala
new file mode 100644
index 00000000..1610bde6
--- /dev/null
+++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala
@@ -0,0 +1,48 @@
+/*
+ * Copyright 2021 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.spark.commons.implicits
+
+import org.apache.spark.sql.types._
+
+import scala.util.Try
+
+object StructFieldImplicits {
+  implicit class StructFieldEnhancements(val structField: StructField) {
+    def getMetadataString(key: String): Option[String] = {
+      Try(structField.metadata.getString(key)).toOption
+    }
+
+    def getMetadataChar(key: String): Option[Char] = {
+      val resultString = Try(structField.metadata.getString(key)).toOption
+      resultString.flatMap { s =>
+        if (s.length == 1) {
+          Option(s(0))
+        } else {
+          None
+        }
+      }
+    }
+
+    def getMetadataStringAsBoolean(key: String): Option[Boolean] = {
+      Try(structField.metadata.getString(key).toBoolean).toOption
+    }
+
+
+    def hasMetadataKey(key: String): Boolean = {
+      structField.metadata.contains(key)
+    }
+  }
+}

From b88af07f6a258abffe13fd6ad5c5b71a81bac3c2 Mon Sep 17 00:00:00 2001
From: Adrian Olosutean
Date: Thu, 13 Jan 2022 11:31:58 +0100
Subject: [PATCH 02/19] #21 functions from enceladus SchemaUtils put into
 SchemaUtils and StructTypeImplicits

---
 .../implicits/StructFieldImplicits.scala      |  16 +
 .../implicits/StructTypeImplicits.scala       | 432 ++++++++++++++++++
 .../spark/commons/schema/MetadataKeys.scala   |  22 +
 .../commons/{ => schema}/SchemaUtils.scala    | 143 +++++-
 .../absa/spark/commons/SchemaUtilSuite.scala  |  95 ++++
 .../absa/spark/commons/SchemaUtilsSpec.scala  |   2 +
 .../implicits/StructFieldImplicitsTest.scala  |  36 ++
 .../implicits/StructTypeImplicitsTest.scala   | 390 ++++++++++++++++
 8 files changed, 1134 insertions(+), 2 deletions(-)
 create mode 100644 src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala
 create mode 100644 src/main/scala/za/co/absa/spark/commons/schema/MetadataKeys.scala
 rename src/main/scala/za/co/absa/spark/commons/{ => schema}/SchemaUtils.scala (67%)
 create mode 100644 src/test/scala/za/co/absa/spark/commons/SchemaUtilSuite.scala
 create mode 100644 src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala
 create mode 100644 src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsTest.scala

diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala
index 1610bde6..75d2f4bb 100644
--- a/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala
+++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala
@@ -17,6 +17,8 @@
 package za.co.absa.spark.commons.implicits

 import org.apache.spark.sql.types._
+import za.co.absa.spark.commons.schema.MetadataKeys
+
 import scala.util.Try

 object StructFieldImplicits {
@@ -44,5 +46,19 @@ object StructFieldImplicits {
     def hasMetadataKey(key: String): Boolean = {
       structField.metadata.contains(key)
     }
+
+    /**
+      * Determines the name of a field.
+      * The name is overridden by the "sourcecolumn" value in the metadata if it exists.
+      *
+      * @return Metadata "sourcecolumn" if it exists, or field.name
+      */
+    def getFieldNameOverriddenByMetadata(): String = {
+      if (structField.metadata.contains(MetadataKeys.SourceColumn)) {
+        structField.metadata.getString(MetadataKeys.SourceColumn)
+      } else {
+        structField.name
+      }
+    }
   }
 }
diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala
new file mode 100644
index 00000000..b50f0404
--- /dev/null
+++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala
@@ -0,0 +1,432 @@
+/*
+ * Copyright 2021 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.spark.commons.implicits
+
+import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType}
+import za.co.absa.spark.commons.schema.MetadataKeys
+import za.co.absa.spark.commons.schema.SchemaUtils.{appendPath, getAllArraySubPaths, isCommonSubPath}
+
+import scala.annotation.tailrec
+import scala.util.Try
+
+object StructTypeImplicits {
+  implicit class StructTypeEnhancements(val schema: StructType) {
+    /**
+      * Get a field from a text path and a given schema
+      * @param path The dot-separated path to the field
+      * @return Some(the requested field) or None if the field does not exist
+      */
+    def getField(path: String): Option[StructField] = {
+
+      @tailrec
+      def goThroughArrayDataType(dataType: DataType): DataType = {
+        dataType match {
+          case ArrayType(dt, _) => goThroughArrayDataType(dt)
+          case result => result
+        }
+      }
+
+      @tailrec
+      def examineStructField(names: List[String], structField: StructField): Option[StructField] = {
+        if (names.isEmpty) {
+          Option(structField)
+        } else {
+          structField.dataType match {
+            case struct: StructType => examineStructField(names.tail, struct(names.head))
+            case ArrayType(el: DataType, _) =>
+              goThroughArrayDataType(el) match {
+                case struct: StructType => examineStructField(names.tail, struct(names.head))
+                case _ => None
+              }
+            case _ => None
+          }
+        }
+      }
+
+      val pathTokens = path.split('.').toList
+      Try {
+        examineStructField(pathTokens.tail, schema(pathTokens.head))
+      }.getOrElse(None)
+    }
+
+    /**
+      * Get a type of a field from a text path and a given schema
+      *
+      * @param path The dot-separated path to the field
+      * @return Some(the type of the field) or None if the field does not exist
+      */
+    def getFieldType(path: String): Option[DataType] = {
+      getField(path).map(_.dataType)
+    }
+
+    /**
+      * Checks if the specified path is an array of structs
+      *
+      * @param path The dot-separated path to the field
+      * @return true if the field is an array of structs
+      */
+    def isColumnArrayOfStruct(path: String): Boolean = {
+      getFieldType(path) match {
+        case Some(dt) =>
+          dt match {
+            case arrayType: ArrayType =>
+              arrayType.elementType match {
+                case _: StructType => true
+                case _ => false
+              }
+            case _ => false
+          }
+        case None => false
+      }
+    }
+
+    /**
+      * Get nullability of a field from a text path and a given schema
+      *
+      * @param path The dot-separated path to the field
+      * @return Some(nullable) or None if the field does not exist
+      */
+    def getFieldNullability(path: String): Option[Boolean] = {
+      getField(path).map(_.nullable)
+    }
+
+    /**
+      * Checks if a field specified by a path and a schema exists
+      * @param path The dot-separated path to the field
+      * @return True if the field exists, false otherwise
+      */
+    def fieldExists(path: String): Boolean = {
+      getField(path).nonEmpty
+    }
+    /**
+      * Returns all renames in the provided schema.
+      * @param includeIfPredecessorChanged if set to true, fields are included even if their names have not changed but
+      *                                    a predecessor's (parent, grandparent etc.) has
+      * @return the keys of the returned map are the columns' names after renames, the values are the source columns;
+      *         the names are full paths denoted in dot notation
+      */
+    def getRenamesInSchema(includeIfPredecessorChanged: Boolean = true): Map[String, String] = {
+
+      def getRenamesRecursively(path: String,
+                                sourcePath: String,
+                                struct: StructType,
+                                renamesAcc: Map[String, String],
+                                predecessorChanged: Boolean): Map[String, String] = {
+        import za.co.absa.spark.commons.implicits.StructFieldImplicits.StructFieldEnhancements
+
+        struct.fields.foldLeft(renamesAcc) { (renamesSoFar, field) =>
+          val fieldFullName = appendPath(path, field.name)
+          val fieldSourceName = field.getMetadataString(MetadataKeys.SourceColumn).getOrElse(field.name)
+          val fieldFullSourceName = appendPath(sourcePath, fieldSourceName)
+
+          val (renames, renameOnPath) = if ((fieldSourceName != field.name) || (predecessorChanged && includeIfPredecessorChanged)) {
+            (renamesSoFar + (fieldFullName -> fieldFullSourceName), true)
+          } else {
+            (renamesSoFar, predecessorChanged)
+          }
+
+          field.dataType match {
+            case st: StructType => getRenamesRecursively(fieldFullName, fieldFullSourceName, st, renames, renameOnPath)
+            case at: ArrayType => getStructInArray(at.elementType).fold(renames) { item =>
+              getRenamesRecursively(fieldFullName, fieldFullSourceName, item, renames, renameOnPath)
+            }
+            case _ => renames
+          }
+        }
+      }
+
+      @tailrec
+      def getStructInArray(dataType: DataType): Option[StructType] = {
+        dataType match {
+          case st: StructType => Option(st)
+          case at: ArrayType => getStructInArray(at.elementType)
+          case _ => None
+        }
+      }
+
+      getRenamesRecursively("", "", schema, Map.empty, predecessorChanged = false)
+    }
+
+    /**
+      * Get first array column's path out of complete path.
+      *
+      * E.g. if the path argument is "a.b.c.d.e" where b and d are arrays, "a.b" will be returned.
+      *
+      * @param path The path to the attribute
+      * @return The path of the first array field or "" if none were found
+      */
+    def getFirstArrayPath(path: String): String = {
+      @tailrec
+      def helper(remPath: Seq[String], pathAcc: Seq[String]): Seq[String] = {
+        if (remPath.isEmpty) Seq() else {
+          val currPath = (pathAcc :+ remPath.head).mkString(".")
+          val currType = getFieldType(currPath)
+          currType match {
+            case Some(_: ArrayType) => pathAcc :+ remPath.head
+            case Some(_) => helper(remPath.tail, pathAcc :+ remPath.head)
+            case None => Seq()
+          }
+        }
+      }
+
+      val pathToks = path.split('.')
+      helper(pathToks, Seq()).mkString(".")
+    }
+
+    /**
+      * Get all array columns' paths out of complete path.
+      *
+      * E.g. if the path argument is "a.b.c.d.e" where b and d are arrays, "a.b" and "a.b.c.d" will be returned.
+ * + * @param path The path to the attribute + * @return Seq of dot-separated paths for all array fields in the provided path + */ + def getAllArraysInPath(path: String): Seq[String] = { + @tailrec + def helper(remPath: Seq[String], pathAcc: Seq[String], arrayAcc: Seq[String]): Seq[String] = { + if (remPath.isEmpty) arrayAcc else { + val currPath = (pathAcc :+ remPath.head).mkString(".") + val currType = getFieldType(currPath) + currType match { + case Some(_: ArrayType) => + val strings = pathAcc :+ remPath.head + helper(remPath.tail, strings, arrayAcc :+ strings.mkString(".")) + case Some(_) => helper(remPath.tail, pathAcc :+ remPath.head, arrayAcc) + case None => arrayAcc + } + } + } + + val pathToks = path.split("\\.") + helper(pathToks, Seq(), Seq()) + } + + /** + * For a given list of field paths determines the deepest common array path. + * + * For instance, if given 'a.b', 'a.b.c', 'a.b.c.d' where b and c are arrays the common deepest array + * path is 'a.b.c'. + * + * If any of the arrays are on diverging paths this function returns None. + * + * The purpose of the function is to determine the order of explosions to be made before the dataframe can be + * joined on a field inside an array. + * + * @param fieldPaths A list of paths to analyze + * @return Returns a common array path if there is one and None if any of the arrays are on diverging paths + */ + def getDeepestCommonArrayPath(fieldPaths: Seq[String]): Option[String] = { + val arrayPaths = fieldPaths.flatMap(path => getAllArraysInPath(path)).distinct + + if (arrayPaths.nonEmpty && isCommonSubPath(arrayPaths: _*)) { + Some(arrayPaths.maxBy(_.length)) + } else { + None + } + } + + /** + * For a field path determines the deepest array path. + * + * For instance, if given 'a.b.c.d' where b and c are arrays the deepest array is 'a.b.c'. 
+ * + * @param fieldPath A path to analyze + * @return Returns a common array path if there is one and None if any of the arrays are on diverging paths + */ + def getDeepestArrayPath(fieldPath: String): Option[String] = { + val arrayPaths = getAllArraysInPath(fieldPath) + + if (arrayPaths.nonEmpty) { + Some(arrayPaths.maxBy(_.length)) + } else { + None + } + } + + /** + * Determine the name of a field + * Will override to "sourcecolumn" in the Metadata if it exists + * + * @param field field to work with + * @return Metadata "sourcecolumn" if it exists or field.name + */ + def getFieldNameOverriddenByMetadata(field: StructField): String = { + if (field.metadata.contains(MetadataKeys.SourceColumn)) { + field.metadata.getString(MetadataKeys.SourceColumn) + } else { + field.name + } + } + + /** + * Get paths for all array fields in the schema + * + * @return Seq of dot separated paths of fields in the schema, which are of type Array + */ + def getAllArrayPaths(): Seq[String] = { + schema.fields.flatMap(f => getAllArraySubPaths("", f.name, f.dataType)).toSeq + } + + /** + * Get a closest unique column name + * + * @param desiredName A prefix to use for the column name + * @param schema A schema to validate if the column already exists + * @return A name that can be used as a unique column name + */ + def getClosestUniqueName(desiredName: String, schema: StructType): String = { + var exists = true + var columnName = "" + var i = 0 + while (exists) { + columnName = if (i == 0) desiredName else s"${desiredName}_$i" + exists = schema.fields.exists(_.name.compareToIgnoreCase(columnName) == 0) + i += 1 + } + columnName + } + + /** + * Checks if a field is the only field in a struct + * + * @param column A column to check + * @return true if the column is the only column in a struct + */ + def isOnlyField(column: String): Boolean = { + def structHelper(structField: StructType, path: Seq[String]): Boolean = { + val currentField = path.head + val isLeaf = path.lengthCompare(1) <= 0 + var isOnlyField = false + structField.fields.foreach(field => + if (field.name == currentField) { + if (isLeaf) { + isOnlyField = structField.fields.length == 1 + } else { + field.dataType match { + case st: StructType => + isOnlyField = structHelper(st, path.tail) + case _: ArrayType => + throw new IllegalArgumentException( + s"SchemaUtils.isOnlyField() does not support checking struct fields inside an array") + case _ => + throw new IllegalArgumentException( + s"Primitive fields cannot have child fields $currentField is a primitive in $column") + } + } + } + ) + isOnlyField + } + val path = column.split('.') + structHelper(schema, path) + } + + /** + * Checks if a field is an array that is not nested in another array + * + * @param fieldPathName A field to check + * @return true if a field is an array that is not nested in another array + */ + def isNonNestedArray(fieldPathName: String): Boolean = { + def structHelper(structField: StructType, path: Seq[String]): Boolean = { + val currentField = path.head + val isLeaf = path.lengthCompare(1) <= 0 + var isArray = false + structField.fields.foreach(field => + if (field.name == currentField) { + field.dataType match { + case st: StructType => + if (!isLeaf) { + isArray = structHelper(st, path.tail) + } + case _: ArrayType => + if (isLeaf) { + isArray = true + } + case _ => + if (!isLeaf) { + throw new IllegalArgumentException( + s"Primitive fields cannot have child fields $currentField is a primitive in $fieldPathName") + } + } + } + ) + isArray + } + + val path = 
fieldPathName.split('.') + structHelper(schema, path) + } + + /** + * Checks if a field is an array + * + * @param fieldPathName A field to check + * @return true if the specified field is an array + */ + def isArray(fieldPathName: String): Boolean = { + @tailrec + def arrayHelper(arrayField: ArrayType, path: Seq[String]): Boolean = { + val currentField = path.head + val isLeaf = path.lengthCompare(1) <= 0 + + arrayField.elementType match { + case st: StructType => structHelper(st, path.tail) + case ar: ArrayType => arrayHelper(ar, path) + case _ => + if (!isLeaf) { + throw new IllegalArgumentException( + s"Primitive fields cannot have child fields $currentField is a primitive in $fieldPathName") + } + false + } + } + + def structHelper(structField: StructType, path: Seq[String]): Boolean = { + val currentField = path.head + val isLeaf = path.lengthCompare(1) <= 0 + var isArray = false + structField.fields.foreach(field => + if (field.name == currentField) { + field.dataType match { + case st: StructType => + if (!isLeaf) { + isArray = structHelper(st, path.tail) + } + case ar: ArrayType => + if (isLeaf) { + isArray = true + } else { + isArray = arrayHelper(ar, path) + } + case _ => + if (!isLeaf) { + throw new IllegalArgumentException( + s"Primitive fields cannot have child fields $currentField is a primitive in $fieldPathName") + } + } + } + ) + isArray + } + + val path = fieldPathName.split('.') + structHelper(schema, path) + } + } + +} diff --git a/src/main/scala/za/co/absa/spark/commons/schema/MetadataKeys.scala b/src/main/scala/za/co/absa/spark/commons/schema/MetadataKeys.scala new file mode 100644 index 00000000..dc6a7aeb --- /dev/null +++ b/src/main/scala/za/co/absa/spark/commons/schema/MetadataKeys.scala @@ -0,0 +1,22 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.spark.commons.schema + +object MetadataKeys { + // all + val SourceColumn = "sourcecolumn" +} diff --git a/src/main/scala/za/co/absa/spark/commons/SchemaUtils.scala b/src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala similarity index 67% rename from src/main/scala/za/co/absa/spark/commons/SchemaUtils.scala rename to src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala index c3300a41..9096958c 100644 --- a/src/main/scala/za/co/absa/spark/commons/SchemaUtils.scala +++ b/src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala @@ -14,14 +14,44 @@ * limitations under the License. */ -package za.co.absa.spark.commons +package za.co.absa.spark.commons.schema import org.apache.spark.sql.functions.{col, struct} -import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame} import za.co.absa.spark.commons.adapters.HofsAdapter +import scala.annotation.tailrec + object SchemaUtils extends HofsAdapter { + + /** + * Returns the parent path of a field. Returns an empty string if a root level field name is provided. 
+    * @param columnName A fully qualified column name
+    * @return The parent column name or an empty string if the input column is a root level column
+    */
+  def getParentPath(columnName: String): String = {
+    val index = columnName.lastIndexOf('.')
+    if (index > 0) {
+      columnName.substring(0, index)
+    } else {
+      ""
+    }
+  }
+
+  /**
+    * Converts a fully qualified field name (including the path to it, i.e. its containing fields) to a unique top
+    * level field name without dot notation
+    *
+    * @param path the fully qualified field name
+    * @return unique top level field name
+    */
+  def unpath(path: String): String = {
+    path.replace("_", "__")
+      .replace('.', '_')
+  }
+
   /**
     * Compares 2 array fields of a dataframe schema.
     *
@@ -90,6 +120,40 @@ object SchemaUtils extends HofsAdapter {
     }
   }

+  /**
+    * Determine if a datatype is a primitive one
+    */
+  def isPrimitive(dt: DataType): Boolean = dt match {
+    case _: BinaryType
+         | _: BooleanType
+         | _: ByteType
+         | _: DateType
+         | _: DecimalType
+         | _: DoubleType
+         | _: FloatType
+         | _: IntegerType
+         | _: LongType
+         | _: NullType
+         | _: ShortType
+         | _: StringType
+         | _: TimestampType => true
+    case _ => false
+  }
+
+  /**
+    * For an array of arrays of arrays, ... get the final element type at the bottom of the array
+    *
+    * @param arrayType An array data type from a Spark dataframe schema
+    * @return A non-array data type at the bottom of array nesting
+    */
+  @tailrec
+  final def getDeepestArrayType(arrayType: ArrayType): DataType = {
+    arrayType.elementType match {
+      case a: ArrayType => getDeepestArrayType(a)
+      case b => b
+    }
+  }
+
   /**
     * Finds all differences of two StructFields and returns their paths
     *
@@ -230,4 +294,79 @@ object SchemaUtils extends HofsAdapter {
   private def getMapOfFields(schema: StructType): Map[String, StructField] = {
     schema.map(field => field.name.toLowerCase() -> field).toMap
   }
+
+  /**
+    * Get paths for all array subfields of the given datatype
+    */
+  def getAllArraySubPaths(path: String, name: String, dt: DataType): Seq[String] = {
+    val currPath = appendPath(path, name)
+    dt match {
+      case s: StructType => s.fields.flatMap(f => getAllArraySubPaths(currPath, f.name, f.dataType))
+      case _@ArrayType(elType, _) => getAllArraySubPaths(path, name, elType) :+ currPath
+      case _ => Seq()
+    }
+  }
+
+  /**
+    * For a given list of field paths determines if any path pair is a subset of one another.
+    *
+    * For instance,
+    * - 'a.b', 'a.b.c', 'a.b.c.d' have this property.
+    * - 'a.b', 'a.b.c', 'a.x.y' does NOT have it, since 'a.b.c' and 'a.x.y' have diverging subpaths.
+    *
+    * @param paths A list of paths to be analyzed
+    * @return true if for all paths the above property holds
+    */
+  def isCommonSubPath(paths: String*): Boolean = {
+    def sliceRoot(paths: Seq[Seq[String]]): Seq[Seq[String]] = {
+      paths.map(path => path.drop(1)).filter(_.nonEmpty)
+    }
+
+    var isParentCommon = true // for Seq() the property holds by convention
+    var restOfPaths: Seq[Seq[String]] = paths.map(_.split('.').toSeq).filter(_.nonEmpty)
+    while (isParentCommon && restOfPaths.nonEmpty) {
+      val parent = restOfPaths.head.head
+      isParentCommon = restOfPaths.forall(path => path.head == parent)
+      restOfPaths = sliceRoot(restOfPaths)
+    }
+    isParentCommon
+  }
+
+  /**
+    * Appends a new attribute to a path; if the path is empty, returns just the field name.
+    * @param path      The dot-separated existing path
+    * @param fieldName Name of the field to be appended to the path
+    * @return The path with the new field appended, or the field itself if the path is empty
+    */
+  def appendPath(path: String, fieldName: String): String = {
+    if (path.isEmpty) {
+      fieldName
+    } else if (fieldName.isEmpty) {
+      path
+    } else {
+      s"$path.$fieldName"
+    }
+  }
+
+
+  /**
+    * Checks if casting between the given types always succeeds
+    *
+    * @param sourceType A type to be cast
+    * @param targetType A type to be cast to
+    * @return true if casting never fails
+    */
+  def isCastAlwaysSucceeds(sourceType: DataType, targetType: DataType): Boolean = {
+    (sourceType, targetType) match {
+      case (_: StructType, _) | (_: ArrayType, _) => false
+      case (a, b) if a == b => true
+      case (_, _: StringType) => true
+      case (_: ByteType, _: ShortType | _: IntegerType | _: LongType) => true
+      case (_: ShortType, _: IntegerType | _: LongType) => true
+      case (_: IntegerType, _: LongType) => true
+      case (_: DateType, _: TimestampType) => true
+      case _ => false
+    }
+  }
 }
diff --git a/src/test/scala/za/co/absa/spark/commons/SchemaUtilSuite.scala b/src/test/scala/za/co/absa/spark/commons/SchemaUtilSuite.scala
new file mode 100644
index 00000000..38fd7b1a
--- /dev/null
+++ b/src/test/scala/za/co/absa/spark/commons/SchemaUtilSuite.scala
@@ -0,0 +1,95 @@
+/*
+ * Copyright 2021 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package za.co.absa.spark.commons + +import org.apache.spark.sql.types.{ArrayType, ByteType, DateType, DecimalType, IntegerType, LongType, MetadataBuilder, ShortType, StringType, StructField, StructType, TimestampType} +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers +import za.co.absa.spark.commons.schema.SchemaUtils._ + +class SchemaUtilSuite extends AnyFunSuite with Matchers { + // scalastyle:off magic.number + + + test ("Test isCastAlwaysSucceeds()") { + assert(!isCastAlwaysSucceeds(StructType(Seq()), StringType)) + assert(!isCastAlwaysSucceeds(ArrayType(StringType), StringType)) + assert(!isCastAlwaysSucceeds(StringType, ByteType)) + assert(!isCastAlwaysSucceeds(StringType, ShortType)) + assert(!isCastAlwaysSucceeds(StringType, IntegerType)) + assert(!isCastAlwaysSucceeds(StringType, LongType)) + assert(!isCastAlwaysSucceeds(StringType, DecimalType(10,10))) + assert(!isCastAlwaysSucceeds(StringType, DateType)) + assert(!isCastAlwaysSucceeds(StringType, TimestampType)) + assert(!isCastAlwaysSucceeds(StructType(Seq()), StructType(Seq()))) + assert(!isCastAlwaysSucceeds(ArrayType(StringType), ArrayType(StringType))) + + assert(!isCastAlwaysSucceeds(ShortType, ByteType)) + assert(!isCastAlwaysSucceeds(IntegerType, ByteType)) + assert(!isCastAlwaysSucceeds(IntegerType, ShortType)) + assert(!isCastAlwaysSucceeds(LongType, ByteType)) + assert(!isCastAlwaysSucceeds(LongType, ShortType)) + assert(!isCastAlwaysSucceeds(LongType, IntegerType)) + + assert(isCastAlwaysSucceeds(StringType, StringType)) + assert(isCastAlwaysSucceeds(ByteType, StringType)) + assert(isCastAlwaysSucceeds(ShortType, StringType)) + assert(isCastAlwaysSucceeds(IntegerType, StringType)) + assert(isCastAlwaysSucceeds(LongType, StringType)) + assert(isCastAlwaysSucceeds(DecimalType(10,10), StringType)) + assert(isCastAlwaysSucceeds(DateType, StringType)) + assert(isCastAlwaysSucceeds(TimestampType, StringType)) + assert(isCastAlwaysSucceeds(StringType, StringType)) + + assert(isCastAlwaysSucceeds(ByteType, ByteType)) + assert(isCastAlwaysSucceeds(ByteType, ShortType)) + assert(isCastAlwaysSucceeds(ByteType, IntegerType)) + assert(isCastAlwaysSucceeds(ByteType, LongType)) + assert(isCastAlwaysSucceeds(ShortType, ShortType)) + assert(isCastAlwaysSucceeds(ShortType, IntegerType)) + assert(isCastAlwaysSucceeds(ShortType, LongType)) + assert(isCastAlwaysSucceeds(IntegerType, IntegerType)) + assert(isCastAlwaysSucceeds(IntegerType, LongType)) + assert(isCastAlwaysSucceeds(LongType, LongType)) + assert(isCastAlwaysSucceeds(DateType, TimestampType)) + } + + test("Test isCommonSubPath()") { + assert (isCommonSubPath()) + assert (isCommonSubPath("a")) + assert (isCommonSubPath("a.b.c.d.e.f", "a.b.c.d", "a.b.c", "a.b", "a")) + assert (!isCommonSubPath("a.b.c.d.e.f", "a.b.c.x", "a.b.c", "a.b", "a")) + } + + test("unpath - empty string remains empty") { + val result = unpath("") + val expected = "" + assert(result == expected) + } + + test("unpath - underscores get doubled") { + val result = unpath("one_two__three") + val expected = "one__two____three" + assert(result == expected) + } + + test("unpath - dot notation conversion") { + val result = unpath("grand_parent.parent.first_child") + val expected = "grand__parent_parent_first__child" + assert(result == expected) + } +} diff --git a/src/test/scala/za/co/absa/spark/commons/SchemaUtilsSpec.scala b/src/test/scala/za/co/absa/spark/commons/SchemaUtilsSpec.scala index 9cf1cf7c..95d08e4d 100644 --- 
a/src/test/scala/za/co/absa/spark/commons/SchemaUtilsSpec.scala +++ b/src/test/scala/za/co/absa/spark/commons/SchemaUtilsSpec.scala @@ -17,9 +17,11 @@ package za.co.absa.spark.commons import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.types.{ArrayType, ByteType, DateType, DecimalType, IntegerType, LongType, MetadataBuilder, ShortType, StringType, StructField, StructType, TimestampType} import org.scalatest.BeforeAndAfterAll import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers +import za.co.absa.spark.commons.schema.SchemaUtils import za.co.absa.spark.commons.test.SparkTestBase class SchemaUtilsSpec extends AnyFlatSpec with Matchers with BeforeAndAfterAll with SparkTestBase { diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala new file mode 100644 index 00000000..5c33b675 --- /dev/null +++ b/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala @@ -0,0 +1,36 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.spark.commons.implicits + +import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StructField} +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.spark.commons.implicits.StructFieldImplicits.StructFieldEnhancements + +class StructFieldImplicitsTest extends AnyFunSuite { + + private val structFieldNoMetadata = StructField("a", IntegerType) + private val structFieldWithMetadataNotSourceColumn = StructField("a", IntegerType, nullable = false, new MetadataBuilder().putString("meta", "data").build) + private val structFieldWithMetadataSourceColumn = StructField("a", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "override_a").build) + + + test("Testing getFieldNameOverriddenByMetadata") { + assertResult("a")(structFieldNoMetadata.getFieldNameOverriddenByMetadata()) + assertResult("a")(structFieldWithMetadataNotSourceColumn.getFieldNameOverriddenByMetadata()) + assertResult("override_a")(structFieldWithMetadataSourceColumn.getFieldNameOverriddenByMetadata()) + } + +} diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsTest.scala new file mode 100644 index 00000000..afa68008 --- /dev/null +++ b/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsTest.scala @@ -0,0 +1,390 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.spark.commons.implicits + +import org.apache.spark.sql.types.{ArrayType, ByteType, DateType, DecimalType, IntegerType, LongType, MetadataBuilder, ShortType, StringType, StructField, StructType, TimestampType} +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements + +class StructTypeImplicitsTest extends AnyFunSuite { + // scalastyle:off magic.number + + private val schema = StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", StructType(Seq( + StructField("c", IntegerType), + StructField("d", StructType(Seq( + StructField("e", IntegerType))), nullable = true)))), + StructField("f", StructType(Seq( + StructField("g", ArrayType.apply(StructType(Seq( + StructField("h", IntegerType)))))))))) + + private val nestedSchema = StructType(Seq( + StructField("a", IntegerType), + StructField("b", ArrayType(StructType(Seq( + StructField("c", StructType(Seq( + StructField("d", ArrayType(StructType(Seq( + StructField("e", IntegerType)))))))))))))) + + private val arrayOfArraysSchema = StructType(Seq( + StructField("a", ArrayType(ArrayType(IntegerType)), nullable = false), + StructField("b", ArrayType(ArrayType(StructType(Seq( + StructField("c", StringType, nullable = false) + )) + )), nullable = true) + )) + + private val structFieldNoMetadata = StructField("a", IntegerType) + + private val structFieldWithMetadataNotSourceColumn = StructField("a", IntegerType, nullable = false, new MetadataBuilder().putString("meta", "data").build) + private val structFieldWithMetadataSourceColumn = StructField("a", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "override_a").build) + + test("Testing getFieldType") { + + val a = schema.getFieldType("a") + val b = schema.getFieldType("b") + val c = schema.getFieldType("b.c") + val d = schema.getFieldType("b.d") + val e = schema.getFieldType("b.d.e") + val f = schema.getFieldType("f") + val g = schema.getFieldType("f.g") + val h = schema.getFieldType("f.g.h") + + assert(a.get.isInstanceOf[IntegerType]) + assert(b.get.isInstanceOf[StructType]) + assert(c.get.isInstanceOf[IntegerType]) + assert(d.get.isInstanceOf[StructType]) + assert(e.get.isInstanceOf[IntegerType]) + assert(f.get.isInstanceOf[StructType]) + assert(g.get.isInstanceOf[ArrayType]) + assert(h.get.isInstanceOf[IntegerType]) + assert(schema.getFieldType("z").isEmpty) + assert(schema.getFieldType("x.y.z").isEmpty) + assert(schema.getFieldType("f.g.h.a").isEmpty) + } + + test("Testing fieldExists") { + assert(schema.fieldExists("a")) + assert(schema.fieldExists("b")) + assert(schema.fieldExists("b.c")) + assert(schema.fieldExists("b.d")) + assert(schema.fieldExists("b.d.e")) + assert(schema.fieldExists("f")) + assert(schema.fieldExists("f.g")) + assert(schema.fieldExists("f.g.h")) + assert(!schema.fieldExists("z")) + assert(!schema.fieldExists("x.y.z")) + assert(!schema.fieldExists("f.g.h.a")) + } + + test ("Test isColumnArrayOfStruct") { + assert(!schema.isColumnArrayOfStruct("a")) + assert(!schema.isColumnArrayOfStruct("b")) + 
assert(!schema.isColumnArrayOfStruct("b.c")) + assert(!schema.isColumnArrayOfStruct("b.d")) + assert(!schema.isColumnArrayOfStruct("b.d.e")) + assert(!schema.isColumnArrayOfStruct("f")) + assert(schema.isColumnArrayOfStruct("f.g")) + assert(!schema.isColumnArrayOfStruct("f.g.h")) + assert(!nestedSchema.isColumnArrayOfStruct("a")) + assert(nestedSchema.isColumnArrayOfStruct("b")) + assert(nestedSchema.isColumnArrayOfStruct("b.c.d")) + } + + test("getRenamesInSchema - no renames") { + val schema = StructType(Seq( + structFieldNoMetadata, + structFieldWithMetadataNotSourceColumn)) + val result = schema.getRenamesInSchema() + assert(result.isEmpty) + } + + test("getRenamesInSchema - simple rename") { + val schema = StructType(Seq(structFieldWithMetadataSourceColumn)) + val result = schema.getRenamesInSchema() + assert(result == Map("a" -> "override_a")) + + } + + test("getRenamesInSchema - complex with includeIfPredecessorChanged set") { + val sub = StructType(Seq( + StructField("d", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "o").build), + StructField("e", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "e").build), + StructField("f", IntegerType) + )) + val schema = StructType(Seq( + StructField("a", sub, nullable = false, new MetadataBuilder().putString("sourcecolumn", "x").build), + StructField("b", sub, nullable = false, new MetadataBuilder().putString("sourcecolumn", "b").build), + StructField("c", sub) + )) + + val includeIfPredecessorChanged = true + val result = schema.getRenamesInSchema(includeIfPredecessorChanged) + val expected = Map( + "a" -> "x" , + "a.d" -> "x.o", + "a.e" -> "x.e", + "a.f" -> "x.f", + "b.d" -> "b.o", + "c.d" -> "c.o" + ) + + assert(result == expected) + } + + test("getRenamesInSchema - complex with includeIfPredecessorChanged not set") { + val sub = StructType(Seq( + StructField("d", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "o").build), + StructField("e", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "e").build), + StructField("f", IntegerType) + )) + val schema = StructType(Seq( + StructField("a", sub, nullable = false, new MetadataBuilder().putString("sourcecolumn", "x").build), + StructField("b", sub, nullable = false, new MetadataBuilder().putString("sourcecolumn", "b").build), + StructField("c", sub) + )) + + val includeIfPredecessorChanged = false + val result = schema.getRenamesInSchema(includeIfPredecessorChanged) + val expected = Map( + "a" -> "x", + "a.d" -> "x.o", + "b.d" -> "b.o", + "c.d" -> "c.o" + ) + + assert(result == expected) + } + + + test("getRenamesInSchema - array") { + val sub = StructType(Seq( + StructField("renamed", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "rename source").build), + StructField("same", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "same").build), + StructField("f", IntegerType) + )) + val schema = StructType(Seq( + StructField("array1", ArrayType(sub)), + StructField("array2", ArrayType(ArrayType(ArrayType(sub)))), + StructField("array3", ArrayType(IntegerType), nullable = false, new MetadataBuilder().putString("sourcecolumn", "array source").build) + )) + + val includeIfPredecessorChanged = false + val result = schema.getRenamesInSchema(includeIfPredecessorChanged) + val expected = Map( + "array1.renamed" -> "array1.rename source", + "array2.renamed" -> "array2.rename source", + "array3" -> "array source" + ) + + 
assert(result == expected) + } + + + test("getRenamesInSchema - source column used multiple times") { + val sub = StructType(Seq( + StructField("x", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "src").build), + StructField("y", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "src").build) + )) + val schema = StructType(Seq( + StructField("a", sub), + StructField("b", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "src").build) + )) + + val result = schema.getRenamesInSchema() + val expected = Map( + "a.x" -> "a.src", + "a.y" -> "a.src", + "b" -> "src" + ) + + assert(result == expected) + } + + test("Testing getFirstArrayPath") { + assertResult("f.g")(schema.getFirstArrayPath("f.g.h")) + assertResult("f.g")(schema.getFirstArrayPath("f.g")) + assertResult("")(schema.getFirstArrayPath("z.x.y")) + assertResult("")(schema.getFirstArrayPath("b.c.d.e")) + } + + test("Testing getAllArrayPaths") { + assertResult(Seq("f.g"))(schema.getAllArrayPaths()) + val newSchema = schema("b").dataType.asInstanceOf[StructType] + assertResult(Seq())(newSchema.getAllArrayPaths()) + } + + test("Testing getAllArraysInPath") { + assertResult(Seq("b", "b.c.d"))(nestedSchema.getAllArraysInPath("b.c.d.e")) + } + + test("Testing getFieldNullability") { + assert(schema.getFieldNullability("b.d").get) + assert(schema.getFieldNullability("x.y.z").isEmpty) + } + + test("Test getDeepestCommonArrayPath() for a path without an array") { + val schema = StructType(Seq[StructField]( + StructField("a", + StructType(Seq[StructField]( + StructField("b", StringType)) + )))) + + assert (schema.getDeepestCommonArrayPath(Seq("a", "a.b")).isEmpty) + } + + test("Test getDeepestCommonArrayPath() for a path with a single array at top level") { + val schema = StructType(Seq[StructField]( + StructField("a", ArrayType(StructType(Seq[StructField]( + StructField("b", StringType))) + )))) + + val deepestPath = schema.getDeepestCommonArrayPath(Seq("a", "a.b")) + + assert (deepestPath.nonEmpty) + assert (deepestPath.get == "a") + } + + test("Test getDeepestCommonArrayPath() for a path with a single array at nested level") { + val schema = StructType(Seq[StructField]( + StructField("a", StructType(Seq[StructField]( + StructField("b", ArrayType(StringType)))) + ))) + + val deepestPath = schema.getDeepestCommonArrayPath(Seq("a", "a.b")) + + assert (deepestPath.nonEmpty) + assert (deepestPath.get == "a.b") + } + + test("Test getDeepestCommonArrayPath() for a path with several nested arrays of struct") { + val schema = StructType(Seq[StructField]( + StructField("a", ArrayType(StructType(Seq[StructField]( + StructField("b", StructType(Seq[StructField]( + StructField("c", ArrayType(StructType(Seq[StructField]( + StructField("d", StructType(Seq[StructField]( + StructField("e", StringType)) + ))) + )))) + ))) + ))))) + + val deepestPath = schema.getDeepestCommonArrayPath(Seq("a", "a.b", "a.b.c.d.e", "a.b.c.d")) + + assert (deepestPath.nonEmpty) + assert (deepestPath.get == "a.b.c") + } + + test("Test getDeepestArrayPath() for a path without an array") { + val schema = StructType(Seq[StructField]( + StructField("a", + StructType(Seq[StructField]( + StructField("b", StringType)) + )))) + + assert (schema.getDeepestArrayPath("a.b").isEmpty) + } + + test("Test getDeepestArrayPath() for a path with a single array at top level") { + val schema = StructType(Seq[StructField]( + StructField("a", ArrayType(StructType(Seq[StructField]( + StructField("b", StringType))) + 
)))) + + val deepestPath = schema.getDeepestArrayPath("a.b") + + assert (deepestPath.nonEmpty) + assert (deepestPath.get == "a") + } + + test("Test getDeepestArrayPath() for a path with a single array at nested level") { + val schema = StructType(Seq[StructField]( + StructField("a", StructType(Seq[StructField]( + StructField("b", ArrayType(StringType)))) + ))) + + val deepestPath = schema.getDeepestArrayPath("a.b") + val deepestPath2 = schema.getDeepestArrayPath("a") + + assert (deepestPath.nonEmpty) + assert (deepestPath.get == "a.b") + assert (deepestPath2.isEmpty) + } + + test("Test getDeepestArrayPath() for a path with several nested arrays of struct") { + val schema = StructType(Seq[StructField]( + StructField("a", ArrayType(StructType(Seq[StructField]( + StructField("b", StructType(Seq[StructField]( + StructField("c", ArrayType(StructType(Seq[StructField]( + StructField("d", StructType(Seq[StructField]( + StructField("e", StringType)) + ))) + )))) + ))) + ))))) + + val deepestPath = schema.getDeepestArrayPath("a.b.c.d.e") + + assert (deepestPath.nonEmpty) + assert (deepestPath.get == "a.b.c") + } + + test("Test isOnlyField()") { + val schema = StructType(Seq[StructField]( + StructField("a", StringType), + StructField("b", StructType(Seq[StructField]( + StructField("e", StringType), + StructField("f", StringType) + ))), + StructField("c", StructType(Seq[StructField]( + StructField("d", StringType) + ))) + )) + + assert(!schema.isOnlyField("a")) + assert(!schema.isOnlyField("b.e")) + assert(!schema.isOnlyField( "b.f")) + assert(schema.isOnlyField("c.d")) + } + + test("Test getStructField on array of arrays") { + assert(arrayOfArraysSchema.getField("a").contains(StructField("a",ArrayType(ArrayType(IntegerType)),nullable = false))) + assert(arrayOfArraysSchema.getField("b").contains(StructField("b",ArrayType(ArrayType(StructType(Seq(StructField("c",StringType,nullable = false))))), nullable = true))) + assert(arrayOfArraysSchema.getField("b.c").contains(StructField("c",StringType,nullable = false))) + assert(arrayOfArraysSchema.getField("b.d").isEmpty) + } + + test("Test fieldExists") { + assert(schema.fieldExists("a")) + assert(schema.fieldExists("b")) + assert(schema.fieldExists("b.c")) + assert(schema.fieldExists("b.d")) + assert(schema.fieldExists("b.d.e")) + assert(schema.fieldExists("f")) + assert(schema.fieldExists("f.g")) + assert(schema.fieldExists("f.g.h")) + assert(!schema.fieldExists("z")) + assert(!schema.fieldExists("x.y.z")) + assert(!schema.fieldExists("f.g.h.a")) + + assert(arrayOfArraysSchema.fieldExists("a")) + assert(arrayOfArraysSchema.fieldExists("b")) + assert(arrayOfArraysSchema.fieldExists("b.c")) + assert(!arrayOfArraysSchema.fieldExists("b.d")) + } + +} From 330c1fa576ba2172586da94c115a1e5c9cb76c40 Mon Sep 17 00:00:00 2001 From: Adrian Olosutean Date: Thu, 13 Jan 2022 15:21:27 +0100 Subject: [PATCH 03/19] #19 tests and small code fix --- .../implicits/StructFieldImplicits.scala | 3 +- .../implicits/ColumnImplicitsTest.scala | 28 ++++++++++++ .../implicits/StructFieldImplicitsTest.scala | 45 +++++++++++++++++++ 3 files changed, 74 insertions(+), 2 deletions(-) create mode 100644 src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala create mode 100644 src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala index 1610bde6..62063941 100644 --- 
a/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala +++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala @@ -28,7 +28,7 @@ object StructFieldImplicits { def getMetadataChar(key: String): Option[Char] = { val resultString = Try(structField.metadata.getString(key)).toOption resultString.flatMap { s => - if (s.length == 1) { + if (s != null && s.length == 1) { Option(s(0)) } else { None @@ -40,7 +40,6 @@ object StructFieldImplicits { Try(structField.metadata.getString(key).toBoolean).toOption } - def hasMetadataKey(key: String): Boolean = { structField.metadata.contains(key) } diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala new file mode 100644 index 00000000..35b0f537 --- /dev/null +++ b/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala @@ -0,0 +1,28 @@ +package za.co.absa.spark.commons.implicits + +import org.apache.spark.sql.Column +import org.apache.spark.sql.functions.lit +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.spark.commons.implicits.ColumnImplicits.ColumnEnhancements + +class ColumnImplicitsTest extends AnyFunSuite{ + + private val column: Column = lit("abcdefgh") + + test("zeroBasedSubstr with tartPos") { + assertResult("cdefgh")(column.zeroBasedSubstr(2).expr.eval().toString) + assertResult("gh")(column.zeroBasedSubstr(-2).expr.eval().toString) + assertResult("")(column.zeroBasedSubstr(Int.MaxValue).expr.eval().toString) + assertResult("abcdefgh")(column.zeroBasedSubstr(Int.MinValue).expr.eval().toString) + } + + test("zeroBasedSubstr with tartPos and len") { + assertResult("cde")(column.zeroBasedSubstr(2, 3).expr.eval().toString) + assertResult("gh")(column.zeroBasedSubstr(-2, 7).expr.eval().toString) + assertResult("")(column.zeroBasedSubstr(Int.MaxValue, 1).expr.eval().toString) + assertResult("")(column.zeroBasedSubstr(Int.MaxValue, -3).expr.eval().toString) + assertResult("")(column.zeroBasedSubstr(Int.MinValue,2).expr.eval().toString) + assertResult("")(column.zeroBasedSubstr(Int.MinValue,-3).expr.eval().toString) + } + +} diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala new file mode 100644 index 00000000..03e05b37 --- /dev/null +++ b/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala @@ -0,0 +1,45 @@ +package za.co.absa.spark.commons.implicits + +import org.apache.spark.sql.types.{Metadata, StringType, StructField} +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.spark.commons.implicits.StructFieldImplicits.StructFieldEnhancements + +class StructFieldImplicitsTest extends AnyFunSuite { + + def fieldWith(value123: String) = { + val value1 = s"""{ \"a\" : ${value123} }""" + StructField("uu", StringType, true, Metadata.fromJson(value1)) + } + + test("getMetadataString") { + assertResult(Some(""))(fieldWith("\"\"").getMetadataString("a")) + assertResult(None)(fieldWith("123").getMetadataString("a")) + assertResult(Some("ffbfg"))(fieldWith("\"ffbfg\"").getMetadataString("a")) + assertResult(Some(null))(fieldWith("null").getMetadataString("a")) + } + + test("getMetadataChar") { + assertResult(None)(fieldWith("\"\"").getMetadataChar("a")) + assertResult(None)(fieldWith("123").getMetadataChar("a")) + assertResult(Some('g'))(fieldWith("\"g\"").getMetadataChar("a")) + 
+    assertResult(None)(fieldWith("null").getMetadataChar("a"))
+  }
+
+  test("getMetadataStringAsBoolean") {
+    assertResult(None)(fieldWith("\"\"").getMetadataStringAsBoolean("a"))
+    assertResult(None)(fieldWith("123").getMetadataStringAsBoolean("a"))
+    assertResult(Some(true))(fieldWith("\"true\"").getMetadataStringAsBoolean("a"))
+    assertResult(Some(false))(fieldWith("\"false\"").getMetadataStringAsBoolean("a"))
+    assertResult(None)(fieldWith("false").getMetadataStringAsBoolean("a"))
+    assertResult(None)(fieldWith("true").getMetadataStringAsBoolean("a"))
+    assertResult(None)(fieldWith("null").getMetadataStringAsBoolean("a"))
+  }
+
+  test("hasMetadataKey") {
+    assertResult(true)(fieldWith("\"\"").hasMetadataKey("a"))
+    assertResult(false)(fieldWith("123").hasMetadataKey("b"))
+    assertResult(true)(fieldWith("\"hvh\"").hasMetadataKey("a"))
+    assertResult(true)(fieldWith("null").hasMetadataKey("a"))
+  }
+
+}

From 56bfd970b655e092bc6534f8538a20d9ae7ada2b Mon Sep 17 00:00:00 2001
From: Adrian Olosutean
Date: Thu, 13 Jan 2022 15:30:20 +0100
Subject: [PATCH 04/19] #22 fixes

---
 .../implicits/ColumnImplicitsTest.scala       | 16 ++++++
 .../implicits/StructFieldImplicitsTest.scala  | 50 ++++++-------------
 2 files changed, 32 insertions(+), 34 deletions(-)

diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala
index 35b0f537..6e9a85b9 100644
--- a/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala
+++ b/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala
@@ -1,3 +1,19 @@
+/*
+ * Copyright 2021 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package za.co.absa.spark.commons.implicits

 import org.apache.spark.sql.Column
diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala
index fba74392..5257cf61 100644
--- a/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala
+++ b/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala
@@ -1,3 +1,19 @@
+/*
+ * Copyright 2021 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + package za.co.absa.spark.commons.implicits import org.apache.spark.sql.types.{IntegerType, Metadata, MetadataBuilder, StringType, StructField} @@ -50,39 +66,5 @@ class StructFieldImplicitsTest extends AnyFunSuite { assertResult(false)(fieldWith("123").hasMetadataKey("b")) assertResult(true)(fieldWith("\"hvh\"").hasMetadataKey("a")) assertResult(true)(fieldWith("null").hasMetadataKey("a")) -/* - * Copyright 2021 ABSA Group Limited - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package za.co.absa.spark.commons.implicits - -import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StructField} -import org.scalatest.funsuite.AnyFunSuite -import za.co.absa.spark.commons.implicits.StructFieldImplicits.StructFieldEnhancements - -class StructFieldImplicitsTest extends AnyFunSuite { - - private val structFieldNoMetadata = StructField("a", IntegerType) - private val structFieldWithMetadataNotSourceColumn = StructField("a", IntegerType, nullable = false, new MetadataBuilder().putString("meta", "data").build) - private val structFieldWithMetadataSourceColumn = StructField("a", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "override_a").build) - - - test("Testing getFieldNameOverriddenByMetadata") { - assertResult("a")(structFieldNoMetadata.getFieldNameOverriddenByMetadata()) - assertResult("a")(structFieldWithMetadataNotSourceColumn.getFieldNameOverriddenByMetadata()) - assertResult("override_a")(structFieldWithMetadataSourceColumn.getFieldNameOverriddenByMetadata()) } - } From 32021cfd8ccc5dc88511773947b2cca536c83af3 Mon Sep 17 00:00:00 2001 From: Adrian Olosutean Date: Thu, 13 Jan 2022 15:31:15 +0100 Subject: [PATCH 05/19] #22 headers --- .../commons/implicits/ColumnImplicitsTest.scala | 16 ++++++++++++++++ .../implicits/StructFieldImplicitsTest.scala | 16 ++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala index 35b0f537..6e9a85b9 100644 --- a/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala +++ b/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala @@ -1,3 +1,19 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package za.co.absa.spark.commons.implicits import org.apache.spark.sql.Column diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala index 03e05b37..53ba1224 100644 --- a/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala +++ b/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala @@ -1,3 +1,19 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package za.co.absa.spark.commons.implicits import org.apache.spark.sql.types.{Metadata, StringType, StructField} From 1a5b0043dd77ff6aaa3fd241ab45238617957b20 Mon Sep 17 00:00:00 2001 From: Adrian Olosutean Date: Fri, 14 Jan 2022 10:51:05 +0100 Subject: [PATCH 06/19] #19 feedback --- .../absa/spark/commons/implicits/ColumnImplicitsTest.scala | 5 +++-- .../spark/commons/implicits/StructFieldImplicitsTest.scala | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala index 6e9a85b9..20d1e62b 100644 --- a/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala +++ b/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala @@ -25,18 +25,19 @@ class ColumnImplicitsTest extends AnyFunSuite{ private val column: Column = lit("abcdefgh") - test("zeroBasedSubstr with tartPos") { + test("zeroBasedSubstr with startPos") { assertResult("cdefgh")(column.zeroBasedSubstr(2).expr.eval().toString) assertResult("gh")(column.zeroBasedSubstr(-2).expr.eval().toString) assertResult("")(column.zeroBasedSubstr(Int.MaxValue).expr.eval().toString) assertResult("abcdefgh")(column.zeroBasedSubstr(Int.MinValue).expr.eval().toString) } - test("zeroBasedSubstr with tartPos and len") { + test("zeroBasedSubstr with startPos and len") { assertResult("cde")(column.zeroBasedSubstr(2, 3).expr.eval().toString) assertResult("gh")(column.zeroBasedSubstr(-2, 7).expr.eval().toString) assertResult("")(column.zeroBasedSubstr(Int.MaxValue, 1).expr.eval().toString) assertResult("")(column.zeroBasedSubstr(Int.MaxValue, -3).expr.eval().toString) + assertResult("")(column.zeroBasedSubstr(4, -3).expr.eval().toString) assertResult("")(column.zeroBasedSubstr(Int.MinValue,2).expr.eval().toString) assertResult("")(column.zeroBasedSubstr(Int.MinValue,-3).expr.eval().toString) } diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala index 53ba1224..5a325bca 100644 --- a/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala +++ b/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala @@ -38,6 +38,7 @@ class StructFieldImplicitsTest extends AnyFunSuite { 
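
The edge cases added to the tests below follow directly from the length handling of `zeroBasedSubstr(startPos, len)`; a minimal sketch (values illustrative):

```scala
import org.apache.spark.sql.functions.lit
import za.co.absa.spark.commons.implicits.ColumnImplicits.ColumnEnhancements

// a negative length can never select any characters, so the result is empty
lit("abcdefgh").zeroBasedSubstr(4, -3).expr.eval().toString  // ""
// a length overshooting the end takes all the remaining characters
lit("abcdefgh").zeroBasedSubstr(-2, 7).expr.eval().toString  // "gh"
```
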
assertResult(None)(fieldWith("\"\"").getMetadataChar("a")) assertResult(None)(fieldWith("123").getMetadataChar("a")) assertResult(Some('g'))(fieldWith("\"g\"").getMetadataChar("a")) + assertResult(None)(fieldWith("\"abc\"").getMetadataChar("a")) assertResult(None)(fieldWith("null").getMetadataChar("a")) } @@ -51,7 +52,7 @@ class StructFieldImplicitsTest extends AnyFunSuite { assertResult(None)(fieldWith("null").getMetadataStringAsBoolean("a")) } - test("hastMetadataKKey") { + test("hastMetadataKey") { assertResult(true)(fieldWith("\"\"").hasMetadataKey("a")) assertResult(false)(fieldWith("123").hasMetadataKey("b")) assertResult(true)(fieldWith("\"hvh\"").hasMetadataKey("a")) From 2110f30b5f5392f234eb57923b40c1e7e167be94 Mon Sep 17 00:00:00 2001 From: Adrian Olosutean Date: Wed, 19 Jan 2022 14:52:37 +0100 Subject: [PATCH 07/19] #19 renames --- README.md | 9 ++-- .../implicits/StructFieldImplicits.scala | 18 ++++---- .../implicits/StructFieldImplicitsTest.scala | 46 +++++++++---------- 3 files changed, 37 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 772f4ad8..f68db30b 100644 --- a/README.md +++ b/README.md @@ -79,30 +79,31 @@ _Column_ provides implicit methods for transforming Spark Columns ### StructFieldImplicits _StructFieldImplicits_ provides implicit methods for working with StructField objects. +Of them, metadata methods are: 1. Gets the metadata String value given a key ```scala - structField.getMetadataString(key) + structField.metadata.getOptString(key) ``` 2. Gets the metadata Char value given a key if the value is a single character String, it returns the char, otherwise None ```scala - structField.getMetadataChar(key) + structField.metadata.getOptChar(key) ``` 3. Gets the metadata boolean value of a given key, given that it can be transformed into boolean ```scala - structField.getMetadataStringAsBoolean(key) + structField.metadata.getStringAsBoolean(key) ``` 4. 
Checks the structfield if it has the provided key, returns a boolean ```scala - structField.hasMetadataKey(key) + structField.metadata.hasKey(key) ``` # Spark Version Guard diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala index 62063941..fdcb601e 100644 --- a/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala +++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala @@ -20,13 +20,13 @@ import org.apache.spark.sql.types._ import scala.util.Try object StructFieldImplicits { - implicit class StructFieldEnhancements(val structField: StructField) { - def getMetadataString(key: String): Option[String] = { - Try(structField.metadata.getString(key)).toOption + implicit class StructFieldMetadataEnhancement(val metadata: Metadata) { + def getOptString(key: String): Option[String] = { + Try(metadata.getString(key)).toOption } - def getMetadataChar(key: String): Option[Char] = { - val resultString = Try(structField.metadata.getString(key)).toOption + def getOptChar(key: String): Option[Char] = { + val resultString = Try(metadata.getString(key)).toOption resultString.flatMap { s => if (s != null && s.length == 1) { Option(s(0)) @@ -36,12 +36,12 @@ object StructFieldImplicits { } } - def getMetadataStringAsBoolean(key: String): Option[Boolean] = { - Try(structField.metadata.getString(key).toBoolean).toOption + def getOptStringAsBoolean(key: String): Option[Boolean] = { + Try(metadata.getString(key).toBoolean).toOption } - def hasMetadataKey(key: String): Boolean = { - structField.metadata.contains(key) + def hasKey(key: String): Boolean = { + metadata.contains(key) } } } diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala index 5a325bca..2fa0b9ed 100644 --- a/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala +++ b/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala @@ -18,7 +18,7 @@ package za.co.absa.spark.commons.implicits import org.apache.spark.sql.types.{Metadata, StringType, StructField} import org.scalatest.funsuite.AnyFunSuite -import za.co.absa.spark.commons.implicits.StructFieldImplicits.StructFieldEnhancements +import za.co.absa.spark.commons.implicits.StructFieldImplicits.StructFieldMetadataEnhancement class StructFieldImplicitsTest extends AnyFunSuite { @@ -28,35 +28,35 @@ class StructFieldImplicitsTest extends AnyFunSuite { } test("getMetadataString") { - assertResult(Some(""))(fieldWith("\"\"").getMetadataString("a")) - assertResult(None)(fieldWith("123").getMetadataString("a")) - assertResult(Some("ffbfg"))(fieldWith("\"ffbfg\"").getMetadataString("a")) - assertResult(Some(null))(fieldWith("null").getMetadataString("a")) + assertResult(Some(""))(fieldWith("\"\"").metadata.getOptString("a")) + assertResult(None)(fieldWith("123").metadata.getOptString("a")) + assertResult(Some("ffbfg"))(fieldWith("\"ffbfg\"").metadata.getOptString("a")) + assertResult(Some(null))(fieldWith("null").metadata.getOptString("a")) } - test("getMetadataChar") { - assertResult(None)(fieldWith("\"\"").getMetadataChar("a")) - assertResult(None)(fieldWith("123").getMetadataChar("a")) - assertResult(Some('g'))(fieldWith("\"g\"").getMetadataChar("a")) - assertResult(None)(fieldWith("\"abc\"").getMetadataChar("a")) - 
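
After this rename the accessors hang off `structField.metadata` rather than the field itself; a minimal sketch of the new call sites (key and value are placeholders):

```scala
import org.apache.spark.sql.types.{Metadata, StringType, StructField}
import za.co.absa.spark.commons.implicits.StructFieldImplicits.StructFieldMetadataEnhancement

val field = StructField("uu", StringType, nullable = true, Metadata.fromJson("""{ "a" : "g" }"""))

field.metadata.getOptString("a")  // Some("g")
field.metadata.getOptChar("a")    // Some('g'): only single-character strings yield Some
field.metadata.hasKey("b")        // false
```
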
assertResult(None)(fieldWith("null").getMetadataChar("a")) + test("getOptChar") { + assertResult(None)(fieldWith("\"\"").metadata.getOptChar("a")) + assertResult(None)(fieldWith("123").metadata.getOptChar("a")) + assertResult(Some('g'))(fieldWith("\"g\"").metadata.getOptChar("a")) + assertResult(None)(fieldWith("\"abc\"").metadata.getOptChar("a")) + assertResult(None)(fieldWith("null").metadata.getOptChar("a")) } - test("getMetadataStringAsBoolean") { - assertResult(None)(fieldWith("\"\"").getMetadataStringAsBoolean("a")) - assertResult(None)(fieldWith("123").getMetadataStringAsBoolean("a")) - assertResult(Some(true))(fieldWith("\"true\"").getMetadataStringAsBoolean("a")) - assertResult(Some(false))(fieldWith("\"false\"").getMetadataStringAsBoolean("a")) - assertResult(None)(fieldWith("false").getMetadataStringAsBoolean("a")) - assertResult(None)(fieldWith("true").getMetadataStringAsBoolean("a")) - assertResult(None)(fieldWith("null").getMetadataStringAsBoolean("a")) + test("getStringAsBoolean") { + assertResult(None)(fieldWith("\"\"").metadata.getOptStringAsBoolean("a")) + assertResult(None)(fieldWith("123").metadata.getOptStringAsBoolean("a")) + assertResult(Some(true))(fieldWith("\"true\"").metadata.getOptStringAsBoolean("a")) + assertResult(Some(false))(fieldWith("\"false\"").metadata.getOptStringAsBoolean("a")) + assertResult(None)(fieldWith("false").metadata.getOptStringAsBoolean("a")) + assertResult(None)(fieldWith("true").metadata.getOptStringAsBoolean("a")) + assertResult(None)(fieldWith("null").metadata.getOptStringAsBoolean("a")) } test("hastMetadataKey") { - assertResult(true)(fieldWith("\"\"").hasMetadataKey("a")) - assertResult(false)(fieldWith("123").hasMetadataKey("b")) - assertResult(true)(fieldWith("\"hvh\"").hasMetadataKey("a")) - assertResult(true)(fieldWith("null").hasMetadataKey("a")) + assertResult(true)(fieldWith("\"\"").metadata.hasKey("a")) + assertResult(false)(fieldWith("123").metadata.hasKey("b")) + assertResult(true)(fieldWith("\"hvh\"").metadata.hasKey("a")) + assertResult(true)(fieldWith("null").metadata.hasKey("a")) } } From 5f2746a4df60afa59101d53e0cfdc99f29f5bd9a Mon Sep 17 00:00:00 2001 From: Adrian Olosutean Date: Thu, 20 Jan 2022 12:16:17 +0100 Subject: [PATCH 08/19] #22 refactoring --- .../implicits/StructTypeImplicits.scala | 164 +++++++----------- .../implicits/StructTypeImplicitsTest.scala | 98 ++++++++--- 2 files changed, 134 insertions(+), 128 deletions(-) diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala index b50f0404..fe2013d0 100644 --- a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala +++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala @@ -112,6 +112,7 @@ object StructTypeImplicits { def fieldExists(path: String): Boolean = { getField(path).nonEmpty } + /** * Returns all renames in the provided schema. 
* @param includeIfPredecessorChanged if set to true, fields are included even if their name have not changed but @@ -257,21 +258,6 @@ object StructTypeImplicits { } } - /** - * Determine the name of a field - * Will override to "sourcecolumn" in the Metadata if it exists - * - * @param field field to work with - * @return Metadata "sourcecolumn" if it exists or field.name - */ - def getFieldNameOverriddenByMetadata(field: StructField): String = { - if (field.metadata.contains(MetadataKeys.SourceColumn)) { - field.metadata.getString(MetadataKeys.SourceColumn) - } else { - field.name - } - } - /** * Get paths for all array fields in the schema * @@ -285,19 +271,18 @@ object StructTypeImplicits { * Get a closest unique column name * * @param desiredName A prefix to use for the column name - * @param schema A schema to validate if the column already exists * @return A name that can be used as a unique column name */ - def getClosestUniqueName(desiredName: String, schema: StructType): String = { - var exists = true - var columnName = "" - var i = 0 - while (exists) { - columnName = if (i == 0) desiredName else s"${desiredName}_$i" - exists = schema.fields.exists(_.name.compareToIgnoreCase(columnName) == 0) - i += 1 + def getClosestUniqueName(desiredName: String): String = { + def fieldExists(name: String): Boolean = schema.fields.exists(_.name.compareToIgnoreCase(name) == 0) + + if (fieldExists(desiredName)) { + Iterator.from(1) + .map(index => s"${desiredName}_${index}") + .dropWhile(fieldExists).next() + } else { + desiredName } - columnName } /** @@ -310,15 +295,14 @@ object StructTypeImplicits { def structHelper(structField: StructType, path: Seq[String]): Boolean = { val currentField = path.head val isLeaf = path.lengthCompare(1) <= 0 - var isOnlyField = false - structField.fields.foreach(field => + structField.fields.exists(field => if (field.name == currentField) { if (isLeaf) { - isOnlyField = structField.fields.length == 1 + structField.fields.length == 1 } else { field.dataType match { case st: StructType => - isOnlyField = structHelper(st, path.tail) + structHelper(st, path.tail) case _: ArrayType => throw new IllegalArgumentException( s"SchemaUtils.isOnlyField() does not support checking struct fields inside an array") @@ -327,9 +311,8 @@ object StructTypeImplicits { s"Primitive fields cannot have child fields $currentField is a primitive in $column") } } - } + } else false ) - isOnlyField } val path = column.split('.') structHelper(schema, path) @@ -342,34 +325,56 @@ object StructTypeImplicits { * @return true if a field is an array that is not nested in another array */ def isNonNestedArray(fieldPathName: String): Boolean = { - def structHelper(structField: StructType, path: Seq[String]): Boolean = { - val currentField = path.head - val isLeaf = path.lengthCompare(1) <= 0 - var isArray = false - structField.fields.foreach(field => - if (field.name == currentField) { - field.dataType match { - case st: StructType => - if (!isLeaf) { - isArray = structHelper(st, path.tail) - } - case _: ArrayType => - if (isLeaf) { - isArray = true - } - case _ => - if (!isLeaf) { - throw new IllegalArgumentException( - s"Primitive fields cannot have child fields $currentField is a primitive in $fieldPathName") - } - } + val path = fieldPathName.split('.') + structHelper(schema, path, fieldPathName)((_,_) => false) + } + +// @tailrec + private def arrayHelper(fieldPathName: String)(arrayField: ArrayType, path: Seq[String]): Boolean = { + val currentField = path.head + val isLeaf = 
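
The rewritten `getClosestUniqueName` above replaces the mutable counter loop with an `Iterator`; the observable behaviour stays the same, e.g. (a sketch with an illustrative schema):

```scala
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements

val schema = StructType(Seq(
  StructField("value", StringType),
  StructField("value_1", StringType)))

schema.getClosestUniqueName("value")  // "value_2": the first free suffix (comparison is case-insensitive)
schema.getClosestUniqueName("other")  // "other": already unique, returned as-is
```
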
path.lengthCompare(1) <= 0 + + val applyArrayHelper: (ArrayType, Seq[String]) => Boolean = arrayHelper(fieldPathName) + + arrayField.elementType match { + case st: StructType => + structHelper(st, path.tail, fieldPathName)(applyArrayHelper) + case ar: ArrayType => applyArrayHelper(ar, path) + case _ => + if (!isLeaf) { + throw new IllegalArgumentException( + s"Primitive fields cannot have child fields $currentField is a primitive in $fieldPathName") } - ) - isArray + false } + } - val path = fieldPathName.split('.') - structHelper(schema, path) + private def structHelper(structField: StructType, path: Seq[String], fieldPathName: String) + (leafArrFnc: (ArrayType, Seq[String]) => Boolean): Boolean = { + val currentField = path.head + val isLeaf = path.lengthCompare(1) <= 0 + + structField.fields.exists(field => + if (field.name == currentField) { + field.dataType match { + case st: StructType => + if (!isLeaf) { + structHelper(st, path.tail, fieldPathName)(leafArrFnc) + } else false + case ar: ArrayType => + if (isLeaf) { + true + } else { + leafArrFnc(ar, path) + } + case _ => + if (!isLeaf) { + throw new IllegalArgumentException( + s"Primitive fields cannot have child fields $currentField is a primitive in $fieldPathName") + } + false + } + } else false) } /** @@ -379,53 +384,8 @@ object StructTypeImplicits { * @return true if the specified field is an array */ def isArray(fieldPathName: String): Boolean = { - @tailrec - def arrayHelper(arrayField: ArrayType, path: Seq[String]): Boolean = { - val currentField = path.head - val isLeaf = path.lengthCompare(1) <= 0 - - arrayField.elementType match { - case st: StructType => structHelper(st, path.tail) - case ar: ArrayType => arrayHelper(ar, path) - case _ => - if (!isLeaf) { - throw new IllegalArgumentException( - s"Primitive fields cannot have child fields $currentField is a primitive in $fieldPathName") - } - false - } - } - - def structHelper(structField: StructType, path: Seq[String]): Boolean = { - val currentField = path.head - val isLeaf = path.lengthCompare(1) <= 0 - var isArray = false - structField.fields.foreach(field => - if (field.name == currentField) { - field.dataType match { - case st: StructType => - if (!isLeaf) { - isArray = structHelper(st, path.tail) - } - case ar: ArrayType => - if (isLeaf) { - isArray = true - } else { - isArray = arrayHelper(ar, path) - } - case _ => - if (!isLeaf) { - throw new IllegalArgumentException( - s"Primitive fields cannot have child fields $currentField is a primitive in $fieldPathName") - } - } - } - ) - isArray - } - val path = fieldPathName.split('.') - structHelper(schema, path) + structHelper(schema, path, fieldPathName)(arrayHelper(fieldPathName)) } } diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsTest.scala index afa68008..cd017f41 100644 --- a/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsTest.scala +++ b/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsTest.scala @@ -16,11 +16,12 @@ package za.co.absa.spark.commons.implicits -import org.apache.spark.sql.types.{ArrayType, ByteType, DateType, DecimalType, IntegerType, LongType, MetadataBuilder, ShortType, StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.types._ import org.scalatest.funsuite.AnyFunSuite import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements +import za.co.absa.spark.commons.test.SparkTestBase -class 
StructTypeImplicitsTest extends AnyFunSuite { +class StructTypeImplicitsTest extends AnyFunSuite with SparkTestBase{ // scalastyle:off magic.number private val schema = StructType(Seq( @@ -91,7 +92,7 @@ class StructTypeImplicitsTest extends AnyFunSuite { assert(!schema.fieldExists("f.g.h.a")) } - test ("Test isColumnArrayOfStruct") { + test("Test isColumnArrayOfStruct") { assert(!schema.isColumnArrayOfStruct("a")) assert(!schema.isColumnArrayOfStruct("b")) assert(!schema.isColumnArrayOfStruct("b.c")) @@ -135,7 +136,7 @@ class StructTypeImplicitsTest extends AnyFunSuite { val includeIfPredecessorChanged = true val result = schema.getRenamesInSchema(includeIfPredecessorChanged) val expected = Map( - "a" -> "x" , + "a" -> "x", "a.d" -> "x.o", "a.e" -> "x.e", "a.f" -> "x.f", @@ -161,7 +162,7 @@ class StructTypeImplicitsTest extends AnyFunSuite { val includeIfPredecessorChanged = false val result = schema.getRenamesInSchema(includeIfPredecessorChanged) val expected = Map( - "a" -> "x", + "a" -> "x", "a.d" -> "x.o", "b.d" -> "b.o", "c.d" -> "c.o" @@ -188,7 +189,7 @@ class StructTypeImplicitsTest extends AnyFunSuite { val expected = Map( "array1.renamed" -> "array1.rename source", "array2.renamed" -> "array2.rename source", - "array3" -> "array source" + "array3" -> "array source" ) assert(result == expected) @@ -209,7 +210,7 @@ class StructTypeImplicitsTest extends AnyFunSuite { val expected = Map( "a.x" -> "a.src", "a.y" -> "a.src", - "b" -> "src" + "b" -> "src" ) assert(result == expected) @@ -244,7 +245,7 @@ class StructTypeImplicitsTest extends AnyFunSuite { StructField("b", StringType)) )))) - assert (schema.getDeepestCommonArrayPath(Seq("a", "a.b")).isEmpty) + assert(schema.getDeepestCommonArrayPath(Seq("a", "a.b")).isEmpty) } test("Test getDeepestCommonArrayPath() for a path with a single array at top level") { @@ -255,8 +256,8 @@ class StructTypeImplicitsTest extends AnyFunSuite { val deepestPath = schema.getDeepestCommonArrayPath(Seq("a", "a.b")) - assert (deepestPath.nonEmpty) - assert (deepestPath.get == "a") + assert(deepestPath.nonEmpty) + assert(deepestPath.get == "a") } test("Test getDeepestCommonArrayPath() for a path with a single array at nested level") { @@ -267,8 +268,8 @@ class StructTypeImplicitsTest extends AnyFunSuite { val deepestPath = schema.getDeepestCommonArrayPath(Seq("a", "a.b")) - assert (deepestPath.nonEmpty) - assert (deepestPath.get == "a.b") + assert(deepestPath.nonEmpty) + assert(deepestPath.get == "a.b") } test("Test getDeepestCommonArrayPath() for a path with several nested arrays of struct") { @@ -285,8 +286,8 @@ class StructTypeImplicitsTest extends AnyFunSuite { val deepestPath = schema.getDeepestCommonArrayPath(Seq("a", "a.b", "a.b.c.d.e", "a.b.c.d")) - assert (deepestPath.nonEmpty) - assert (deepestPath.get == "a.b.c") + assert(deepestPath.nonEmpty) + assert(deepestPath.get == "a.b.c") } test("Test getDeepestArrayPath() for a path without an array") { @@ -296,7 +297,7 @@ class StructTypeImplicitsTest extends AnyFunSuite { StructField("b", StringType)) )))) - assert (schema.getDeepestArrayPath("a.b").isEmpty) + assert(schema.getDeepestArrayPath("a.b").isEmpty) } test("Test getDeepestArrayPath() for a path with a single array at top level") { @@ -307,8 +308,8 @@ class StructTypeImplicitsTest extends AnyFunSuite { val deepestPath = schema.getDeepestArrayPath("a.b") - assert (deepestPath.nonEmpty) - assert (deepestPath.get == "a") + assert(deepestPath.nonEmpty) + assert(deepestPath.get == "a") } test("Test getDeepestArrayPath() for a path with a single 
array at nested level") { @@ -320,9 +321,9 @@ class StructTypeImplicitsTest extends AnyFunSuite { val deepestPath = schema.getDeepestArrayPath("a.b") val deepestPath2 = schema.getDeepestArrayPath("a") - assert (deepestPath.nonEmpty) - assert (deepestPath.get == "a.b") - assert (deepestPath2.isEmpty) + assert(deepestPath.nonEmpty) + assert(deepestPath.get == "a.b") + assert(deepestPath2.isEmpty) } test("Test getDeepestArrayPath() for a path with several nested arrays of struct") { @@ -339,8 +340,8 @@ class StructTypeImplicitsTest extends AnyFunSuite { val deepestPath = schema.getDeepestArrayPath("a.b.c.d.e") - assert (deepestPath.nonEmpty) - assert (deepestPath.get == "a.b.c") + assert(deepestPath.nonEmpty) + assert(deepestPath.get == "a.b.c") } test("Test isOnlyField()") { @@ -357,14 +358,14 @@ class StructTypeImplicitsTest extends AnyFunSuite { assert(!schema.isOnlyField("a")) assert(!schema.isOnlyField("b.e")) - assert(!schema.isOnlyField( "b.f")) + assert(!schema.isOnlyField("b.f")) assert(schema.isOnlyField("c.d")) } test("Test getStructField on array of arrays") { - assert(arrayOfArraysSchema.getField("a").contains(StructField("a",ArrayType(ArrayType(IntegerType)),nullable = false))) - assert(arrayOfArraysSchema.getField("b").contains(StructField("b",ArrayType(ArrayType(StructType(Seq(StructField("c",StringType,nullable = false))))), nullable = true))) - assert(arrayOfArraysSchema.getField("b.c").contains(StructField("c",StringType,nullable = false))) + assert(arrayOfArraysSchema.getField("a").contains(StructField("a", ArrayType(ArrayType(IntegerType)), nullable = false))) + assert(arrayOfArraysSchema.getField("b").contains(StructField("b", ArrayType(ArrayType(StructType(Seq(StructField("c", StringType, nullable = false))))), nullable = true))) + assert(arrayOfArraysSchema.getField("b.c").contains(StructField("c", StringType, nullable = false))) assert(arrayOfArraysSchema.getField("b.d").isEmpty) } @@ -387,4 +388,49 @@ class StructTypeImplicitsTest extends AnyFunSuite { assert(!arrayOfArraysSchema.fieldExists("b.d")) } + test("Test getClosestUniqueName() is working properly") { + val schema = StructType(Seq[StructField]( + StructField("value", StringType), + StructField("value_1", StringType), + StructField("value_2", StringType) + )) + + // A column name that does not exist + val name1 = schema.getClosestUniqueName("v") + // A column that exists + val name2 = schema.getClosestUniqueName("value") + + assert(name1 == "v") + assert(name2 == "value_3") + } + + val sample = + """{"id":1,"legs":[{"legid":100,"conditions":[{"checks":[{"checkNums":["1","2","3b","4","5c","6"]}],"amount":100}]}]}""" :: + """{"id":2,"legs":[{"legid":200,"conditions":[{"checks":[{"checkNums":["8","9","10b","11","12c","13"]}],"amount":200}]}]}""" :: + """{"id":3,"legs":[{"legid":300,"conditions":[{"checks":[],"amount": 300}]}]}""" :: + """{"id":4,"legs":[{"legid":400,"conditions":[{"checks":null,"amount": 400}]}]}""" :: + """{"id":5,"legs":[{"legid":500,"conditions":[]}]}""" :: + """{"id":6,"legs":[]}""" :: + """{"id":7}""" :: Nil + import spark.implicits._ + val df = spark.read.json(sample.toDS) + + test("Test isNonNestedArray") { + assert(df.schema.isNonNestedArray("legs")) + assert(!df.schema.isNonNestedArray("legs.conditions")) + assert(!df.schema.isNonNestedArray("legs.conditions.checks")) + assert(!df.schema.isNonNestedArray("legs.conditions.checks.checkNums")) + assert(!df.schema.isNonNestedArray("id")) + assert(!df.schema.isNonNestedArray("legs.legid")) + } + + test("Test isNonArray") { + 
assert(df.schema.isArray("legs")) + assert(df.schema.isArray("legs.conditions")) + assert(df.schema.isArray("legs.conditions.checks")) + assert(df.schema.isArray("legs.conditions.checks.checkNums")) + assert(!df.schema.isArray("id")) + assert(!df.schema.isArray("legs.legid")) + } + } From e58d3aaf26f573806c945a7c556cc44934f94cfa Mon Sep 17 00:00:00 2001 From: Adrian Olosutean Date: Thu, 20 Jan 2022 14:22:25 +0100 Subject: [PATCH 09/19] #19 other feedback --- README.md | 4 ++-- .../commons/implicits/StructFieldImplicitsTest.scala | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index f68db30b..f5fbe5c2 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ select to order and positionally filter columns of a DataFrame ### ColumnImplicits -_Column_ provides implicit methods for transforming Spark Columns +_ColumnImplicits_ provide implicit methods for transforming Spark Columns 1. Transforms the column into a booleaan column, checking if values are negative or positive infinity @@ -81,7 +81,7 @@ _Column_ provides implicit methods for transforming Spark Columns _StructFieldImplicits_ provides implicit methods for working with StructField objects. Of them, metadata methods are: -1. Gets the metadata String value given a key +1. Gets the metadata Option[String] value given a key ```scala structField.metadata.getOptString(key) diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala index 2fa0b9ed..055afa9c 100644 --- a/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala +++ b/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala @@ -23,11 +23,11 @@ import za.co.absa.spark.commons.implicits.StructFieldImplicits.StructFieldMetada class StructFieldImplicitsTest extends AnyFunSuite { def fieldWith(value123: String) = { - val value1 = s"""{ \"a\" : ${value123} }""" + val value1 = s"""{ "a" : ${value123} }""" StructField("uu", StringType, true, Metadata.fromJson(value1)) } - test("getMetadataString") { + test("getOptString") { assertResult(Some(""))(fieldWith("\"\"").metadata.getOptString("a")) assertResult(None)(fieldWith("123").metadata.getOptString("a")) assertResult(Some("ffbfg"))(fieldWith("\"ffbfg\"").metadata.getOptString("a")) @@ -42,7 +42,7 @@ class StructFieldImplicitsTest extends AnyFunSuite { assertResult(None)(fieldWith("null").metadata.getOptChar("a")) } - test("getStringAsBoolean") { + test("getOptStringAsBoolean") { assertResult(None)(fieldWith("\"\"").metadata.getOptStringAsBoolean("a")) assertResult(None)(fieldWith("123").metadata.getOptStringAsBoolean("a")) assertResult(Some(true))(fieldWith("\"true\"").metadata.getOptStringAsBoolean("a")) @@ -52,7 +52,7 @@ class StructFieldImplicitsTest extends AnyFunSuite { assertResult(None)(fieldWith("null").metadata.getOptStringAsBoolean("a")) } - test("hastMetadataKey") { + test("hasKey") { assertResult(true)(fieldWith("\"\"").metadata.hasKey("a")) assertResult(false)(fieldWith("123").metadata.hasKey("b")) assertResult(true)(fieldWith("\"hvh\"").metadata.hasKey("a")) From ade77e2df83fa24f36898681a08eaa60c5439ed4 Mon Sep 17 00:00:00 2001 From: Adrian Olosutean Date: Fri, 21 Jan 2022 13:12:50 +0100 Subject: [PATCH 10/19] #22 refactoring --- .../implicits/StructFieldImplicits.scala | 1 - .../implicits/StructTypeImplicits.scala | 107 ++++++++---------- 2 files changed, 46 insertions(+), 62 deletions(-) diff --git 
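
The distinction the two tests above draw can be summarised on a single nested record; a sketch, assuming a `SparkSession` named `spark` is in scope, as `SparkTestBase` provides:

```scala
import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements
import spark.implicits._

val df = spark.read.json(Seq(
  """{"id":1,"legs":[{"legid":100,"conditions":[{"amount":100}]}]}""").toDS)

df.schema.isArray("legs")                     // true
df.schema.isArray("legs.conditions")          // true: an array, even though nested
df.schema.isNonNestedArray("legs")            // true
df.schema.isNonNestedArray("legs.conditions") // false: its parent "legs" is itself an array
```
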
a/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala index 9f741dc3..dc9c7ac5 100644 --- a/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala +++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala @@ -48,7 +48,6 @@ object StructFieldImplicits { /** * Determine the name of a field - * Will override to "sourcecolumn" in the Metadata if it exists * * @return Metadata "sourcecolumn" if it exists or field.name */ diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala index fe2013d0..5017ab9a 100644 --- a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala +++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala @@ -292,30 +292,8 @@ object StructTypeImplicits { * @return true if the column is the only column in a struct */ def isOnlyField(column: String): Boolean = { - def structHelper(structField: StructType, path: Seq[String]): Boolean = { - val currentField = path.head - val isLeaf = path.lengthCompare(1) <= 0 - structField.fields.exists(field => - if (field.name == currentField) { - if (isLeaf) { - structField.fields.length == 1 - } else { - field.dataType match { - case st: StructType => - structHelper(st, path.tail) - case _: ArrayType => - throw new IllegalArgumentException( - s"SchemaUtils.isOnlyField() does not support checking struct fields inside an array") - case _ => - throw new IllegalArgumentException( - s"Primitive fields cannot have child fields $currentField is a primitive in $column") - } - } - } else false - ) - } val path = column.split('.') - structHelper(schema, path) + structArrayHelper(schema, path, column)(applyArrayHelper = false, field => field.fields.length == 1) } /** @@ -326,20 +304,30 @@ object StructTypeImplicits { */ def isNonNestedArray(fieldPathName: String): Boolean = { val path = fieldPathName.split('.') - structHelper(schema, path, fieldPathName)((_,_) => false) + structArrayHelper(schema, path, fieldPathName)(applyArrayHelper = false) + } + + /** + * Checks if a field is an array + * + * @param fieldPathName A field to check + * @return true if the specified field is an array + */ + def isArray(fieldPathName: String): Boolean = { + val path = fieldPathName.split('.') + structArrayHelper(schema, path, fieldPathName)(applyArrayHelper = true) } -// @tailrec - private def arrayHelper(fieldPathName: String)(arrayField: ArrayType, path: Seq[String]): Boolean = { + @tailrec + private def arrayHelper(fieldPathName: String, arrayField: ArrayType, path: Seq[String]) + (applyArrayHelper: Boolean, conditionLeafSh: StructType => Boolean): Boolean = { val currentField = path.head val isLeaf = path.lengthCompare(1) <= 0 - val applyArrayHelper: (ArrayType, Seq[String]) => Boolean = arrayHelper(fieldPathName) - arrayField.elementType match { case st: StructType => - structHelper(st, path.tail, fieldPathName)(applyArrayHelper) - case ar: ArrayType => applyArrayHelper(ar, path) + structArrayHelper(st, path.tail, fieldPathName)(applyArrayHelper, conditionLeafSh) + case ar: ArrayType => arrayHelper(fieldPathName, ar, path) (applyArrayHelper, conditionLeafSh) case _ => if (!isLeaf) { throw new IllegalArgumentException( @@ -349,43 +337,40 @@ object StructTypeImplicits { } } - private def structHelper(structField: StructType, path: Seq[String], fieldPathName: 
String) - (leafArrFnc: (ArrayType, Seq[String]) => Boolean): Boolean = { + private def structArrayHelper(structField: StructType, path: Seq[String], fieldPathName: String) + (applyArrayHelper: Boolean, + conditionLeafSh: StructType => Boolean = _ => false): Boolean = { val currentField = path.head val isLeaf = path.lengthCompare(1) <= 0 structField.fields.exists(field => if (field.name == currentField) { - field.dataType match { - case st: StructType => - if (!isLeaf) { - structHelper(st, path.tail, fieldPathName)(leafArrFnc) - } else false - case ar: ArrayType => - if (isLeaf) { - true - } else { - leafArrFnc(ar, path) - } - case _ => - if (!isLeaf) { - throw new IllegalArgumentException( - s"Primitive fields cannot have child fields $currentField is a primitive in $fieldPathName") - } - false + if (isLeaf && conditionLeafSh(structField)) { + true + } else { + field.dataType match { + case st: StructType => + if (!isLeaf) { + structArrayHelper(st, path.tail, fieldPathName)(applyArrayHelper, conditionLeafSh) + } else false + case ar: ArrayType => + if (isLeaf) { + true + } else if (applyArrayHelper) { + arrayHelper(fieldPathName, ar, path)(applyArrayHelper, conditionLeafSh) + } else { + false + } + case _ => + if (!isLeaf) { + throw new IllegalArgumentException( + s"Primitive fields cannot have child fields $currentField is a primitive in $fieldPathName") + } + false + } } - } else false) - } - - /** - * Checks if a field is an array - * - * @param fieldPathName A field to check - * @return true if the specified field is an array - */ - def isArray(fieldPathName: String): Boolean = { - val path = fieldPathName.split('.') - structHelper(schema, path, fieldPathName)(arrayHelper(fieldPathName)) + } else false + ) } } From 2517a20509e7925f94d68ab7331de0bad3d93e27 Mon Sep 17 00:00:00 2001 From: Adrian Olosutean Date: Mon, 24 Jan 2022 17:55:27 +0100 Subject: [PATCH 11/19] #22 feedback --- .../implicits/StructTypeImplicits.scala | 80 ++++++++----------- 1 file changed, 32 insertions(+), 48 deletions(-) diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala index 5017ab9a..fe18c804 100644 --- a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala +++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala @@ -165,9 +165,9 @@ object StructTypeImplicits { /** * Get first array column's path out of complete path. * - * E.g if the path argument is "a.b.c.d.e" where b and d are arrays, "a.b" will be returned. + * E.g if the path argument is "a.b.c.d.e" where b and d are arrays, "a.b" will be returned. * - * @param path The path to the attribute + * @param path The path to the attribute * @return The path of the first array field or "" if none were found */ def getFirstArrayPath(path: String): String = { @@ -191,9 +191,9 @@ object StructTypeImplicits { /** * Get all array columns' paths out of complete path. * - * E.g. if the path argument is "a.b.c.d.e" where b and d are arrays, "a.b" and "a.b.c.d" will be returned. + * E.g. if the path argument is "a.b.c.d.e" where b and d are arrays, "a.b" and "a.b.c.d" will be returned. 
* - * @param path The path to the attribute + * @param path The path to the attribute * @return Seq of dot-separated paths for all array fields in the provided path */ def getAllArraysInPath(path: String): Seq[String] = { @@ -293,7 +293,7 @@ object StructTypeImplicits { */ def isOnlyField(column: String): Boolean = { val path = column.split('.') - structArrayHelper(schema, path, column)(applyArrayHelper = false, field => field.fields.length == 1) + evaluateConditionsForField(schema, path, column, applyArrayHelper = false, field => field.fields.length == 1) } /** @@ -304,7 +304,7 @@ object StructTypeImplicits { */ def isNonNestedArray(fieldPathName: String): Boolean = { val path = fieldPathName.split('.') - structArrayHelper(schema, path, fieldPathName)(applyArrayHelper = false) + evaluateConditionsForField(schema, path, fieldPathName, applyArrayHelper = false) } /** @@ -315,60 +315,44 @@ object StructTypeImplicits { */ def isArray(fieldPathName: String): Boolean = { val path = fieldPathName.split('.') - structArrayHelper(schema, path, fieldPathName)(applyArrayHelper = true) + evaluateConditionsForField(schema, path, fieldPathName, applyArrayHelper = true) } - @tailrec - private def arrayHelper(fieldPathName: String, arrayField: ArrayType, path: Seq[String]) - (applyArrayHelper: Boolean, conditionLeafSh: StructType => Boolean): Boolean = { + private def evaluateConditionsForField(structField: StructType, path: Seq[String], fieldPathName: String, + applyArrayHelper: Boolean, + conditionLeafSh: StructType => Boolean = _ => false): Boolean = { val currentField = path.head val isLeaf = path.lengthCompare(1) <= 0 - arrayField.elementType match { - case st: StructType => - structArrayHelper(st, path.tail, fieldPathName)(applyArrayHelper, conditionLeafSh) - case ar: ArrayType => arrayHelper(fieldPathName, ar, path) (applyArrayHelper, conditionLeafSh) - case _ => - if (!isLeaf) { - throw new IllegalArgumentException( - s"Primitive fields cannot have child fields $currentField is a primitive in $fieldPathName") - } - false + @tailrec + def arrayHelper(fieldPathName: String, arrayField: ArrayType, path: Seq[String]): Boolean = { + arrayField.elementType match { + case st: StructType => + evaluateConditionsForField(st, path.tail, fieldPathName, applyArrayHelper, conditionLeafSh) + case ar: ArrayType => arrayHelper(fieldPathName, ar, path) + case _ => + if (!isLeaf) { + throw new IllegalArgumentException( + s"Primitive fields cannot have child fields $currentField is a primitive in $fieldPathName") + } + false + } } - } - - private def structArrayHelper(structField: StructType, path: Seq[String], fieldPathName: String) - (applyArrayHelper: Boolean, - conditionLeafSh: StructType => Boolean = _ => false): Boolean = { - val currentField = path.head - val isLeaf = path.lengthCompare(1) <= 0 structField.fields.exists(field => if (field.name == currentField) { if (isLeaf && conditionLeafSh(structField)) { true - } else { - field.dataType match { - case st: StructType => - if (!isLeaf) { - structArrayHelper(st, path.tail, fieldPathName)(applyArrayHelper, conditionLeafSh) - } else false - case ar: ArrayType => - if (isLeaf) { - true - } else if (applyArrayHelper) { - arrayHelper(fieldPathName, ar, path)(applyArrayHelper, conditionLeafSh) - } else { - false - } - case _ => - if (!isLeaf) { - throw new IllegalArgumentException( - s"Primitive fields cannot have child fields $currentField is a primitive in $fieldPathName") - } - false + } else + (field.dataType, isLeaf) match { + case (st: StructType, false) => + 
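
To make the Scaladoc example above concrete, a small schema in which `b` and `d` are the arrays on the path `a.b.c.d.e` (a sketch; the field names mirror the Scaladoc):

```scala
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements

val d = StructField("d", ArrayType(StructType(Seq(StructField("e", StringType)))))
val b = StructField("b", ArrayType(StructType(Seq(StructField("c", StructType(Seq(d)))))))
val schema = StructType(Seq(StructField("a", StructType(Seq(b)))))

schema.getFirstArrayPath("a.b.c.d.e")   // "a.b"
schema.getAllArraysInPath("a.b.c.d.e")  // Seq("a.b", "a.b.c.d")
```
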
evaluateConditionsForField(st, path.tail, fieldPathName, applyArrayHelper, conditionLeafSh) + case (_: ArrayType, true) => true + case (ar: ArrayType, false) if applyArrayHelper => arrayHelper(fieldPathName, ar, path) + case (_: ArrayType, false) => false + case (_, false) => throw new IllegalArgumentException(s"Primitive fields cannot have child fields $currentField is a primitive in $fieldPathName") + case (_, _) => false } - } } else false ) } From 8bde069d924fc8c90ddd6ebc50c735cbfc5e615f Mon Sep 17 00:00:00 2001 From: Adrian Olosutean Date: Tue, 25 Jan 2022 15:38:15 +0100 Subject: [PATCH 12/19] #22 merge --- .../implicits/StructFieldImplicits.scala | 14 -------------- .../commons/implicits/StructTypeImplicits.scala | 5 +++-- .../implicits/StructFieldImplicitsTest.scala | 17 ----------------- .../commons/{ => schema}/SchemaUtilSuite.scala | 4 ++-- 4 files changed, 5 insertions(+), 35 deletions(-) rename src/test/scala/za/co/absa/spark/commons/{ => schema}/SchemaUtilSuite.scala (94%) diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala index 521dada0..ad3306af 100644 --- a/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala +++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala @@ -17,7 +17,6 @@ package za.co.absa.spark.commons.implicits import org.apache.spark.sql.types._ -import za.co.absa.spark.commons.schema.MetadataKeys import scala.util.Try @@ -45,18 +44,5 @@ object StructFieldImplicits { def hasKey(key: String): Boolean = { metadata.contains(key) } - - /** - * Determine the name of a field - * - * @return Metadata "sourcecolumn" if it exists or field.name - */ - def getFieldNameOverriddenByMetadata(): String = { - if (structField.metadata.contains(MetadataKeys.SourceColumn)) { - structField.metadata.getString(MetadataKeys.SourceColumn) - } else { - structField.name - } - } } } diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala index fe18c804..55989782 100644 --- a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala +++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala @@ -17,6 +17,7 @@ package za.co.absa.spark.commons.implicits import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType} +import za.co.absa.spark.commons.implicits.StructFieldImplicits.StructFieldMetadataEnhancement import za.co.absa.spark.commons.schema.MetadataKeys import za.co.absa.spark.commons.schema.SchemaUtils.{appendPath, getAllArraySubPaths, isCommonSubPath} @@ -127,11 +128,11 @@ object StructTypeImplicits { struct: StructType, renamesAcc: Map[String, String], predecessorChanged: Boolean): Map[String, String] = { - import za.co.absa.spark.commons.implicits.StructFieldImplicits.StructFieldEnhancements + import za.co.absa.spark.commons.implicits.StructFieldImplicits.StructFieldMetadataEnhancement struct.fields.foldLeft(renamesAcc) { (renamesSoFar, field) => val fieldFullName = appendPath(path, field.name) - val fieldSourceName = field.getMetadataString(MetadataKeys.SourceColumn).getOrElse(field.name) + val fieldSourceName = field.metadata.getOptString(MetadataKeys.SourceColumn).getOrElse(field.name) val fieldFullSourceName = appendPath(sourcePath, fieldSourceName) val (renames, renameOnPath) = if ((fieldSourceName != field.name) || (predecessorChanged && 
includeIfPredecessorChanged)) { diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala index b0cf4f37..8665be91 100644 --- a/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala +++ b/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala @@ -27,10 +27,6 @@ class StructFieldImplicitsTest extends AnyFunSuite { StructField("uu", StringType, true, Metadata.fromJson(value1)) } - private val structFieldNoMetadata = StructField("a", IntegerType) - private val structFieldWithMetadataNotSourceColumn = StructField("a", IntegerType, nullable = false, new MetadataBuilder().putString("meta", "data").build) - private val structFieldWithMetadataSourceColumn = StructField("a", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "override_a").build) - test("getOptString") { assertResult(Some(""))(fieldWith("\"\"").metadata.getOptString("a")) assertResult(None)(fieldWith("123").metadata.getOptString("a")) @@ -38,19 +34,6 @@ class StructFieldImplicitsTest extends AnyFunSuite { assertResult(Some(null))(fieldWith("null").metadata.getOptString("a")) } - test("Testing getFieldNameOverriddenByMetadata") { - assertResult("a")(structFieldNoMetadata.getFieldNameOverriddenByMetadata()) - assertResult("a")(structFieldWithMetadataNotSourceColumn.getFieldNameOverriddenByMetadata()) - assertResult("override_a")(structFieldWithMetadataSourceColumn.getFieldNameOverriddenByMetadata()) - } - - test("getMetadataString") { - assertResult(Some(""))(fieldWith("\"\"").getMetadataString("a")) - assertResult(None)(fieldWith("123").getMetadataString("a")) - assertResult(Some("ffbfg"))(fieldWith("\"ffbfg\"").getMetadataString("a")) - assertResult(Some(null))(fieldWith("null").getMetadataString("a")) - } - test("getOptChar") { assertResult(None)(fieldWith("\"\"").metadata.getOptChar("a")) assertResult(None)(fieldWith("123").metadata.getOptChar("a")) diff --git a/src/test/scala/za/co/absa/spark/commons/SchemaUtilSuite.scala b/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilSuite.scala similarity index 94% rename from src/test/scala/za/co/absa/spark/commons/SchemaUtilSuite.scala rename to src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilSuite.scala index 38fd7b1a..23460362 100644 --- a/src/test/scala/za/co/absa/spark/commons/SchemaUtilSuite.scala +++ b/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilSuite.scala @@ -14,9 +14,9 @@ * limitations under the License. 
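
A minimal sketch of what `getRenamesInSchema` reports: field names are the keys, their `sourcecolumn` metadata values (the source names) the values. The schema is illustrative, and the parameter is passed by the name the Scaladoc gives it:

```scala
import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StructField, StructType}
import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements

// "sourcecolumn" metadata records that field x was read from source column a
val schema = StructType(Seq(
  StructField("x", IntegerType, nullable = false,
    new MetadataBuilder().putString("sourcecolumn", "a").build())))

schema.getRenamesInSchema(includeIfPredecessorChanged = true)  // Map("x" -> "a")
```
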
*/ -package za.co.absa.spark.commons +package za.co.absa.spark.commons.schema -import org.apache.spark.sql.types.{ArrayType, ByteType, DateType, DecimalType, IntegerType, LongType, MetadataBuilder, ShortType, StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.types._ import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers import za.co.absa.spark.commons.schema.SchemaUtils._ From c3599741fd3ed127e9bcf61b454adfc430067d83 Mon Sep 17 00:00:00 2001 From: Adrian Olosutean Date: Wed, 26 Jan 2022 12:30:19 +0100 Subject: [PATCH 13/19] #22 docs + import fixes --- README.md | 97 +++++++++++++++++++ .../implicits/StructFieldImplicits.scala | 16 +++ .../implicits/StructTypeImplicits.scala | 1 - .../implicits/StructFieldImplicitsTest.scala | 2 +- .../commons/schema/SchemaUtilsSpec.scala | 2 - 5 files changed, 114 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 483033ee..353692ec 100644 --- a/README.md +++ b/README.md @@ -79,6 +79,13 @@ _ColumnImplicits_ provide implicit methods for transforming Spark Columns ### StructFieldImplicits _StructFieldImplicits_ provides implicit methods for working with StructField objects. + +1. Determine the name of a field overriden by metadata + + ```scala + structField.getFieldNameOverriddenByMetadata() + ``` + Of them, metadata methods are: 1. Gets the metadata Option[String] value given a key @@ -105,6 +112,96 @@ Of them, metadata methods are: ```scala structField.metadata.hasKey(key) ``` + +### StructTypeImplicits + +_StructTypeImplicits_ provides implicit methods for working with StructType objects. + + +1. Get a field from a text path + + ```scala + structType.getField(path) + ``` +2. Get a type of a field from a text path + + ```scala + structType.getFieldType(path) + ``` +3. Checks if the specified path is an array of structs + + ```scala + structType.isColumnArrayOfStruct(path) + ``` + +4. Get nullability of a field from a text path + + ```scala + structType.getFieldNullability(path) + ``` + +5. Checks if a field specified by a path exists + + ```scala + structType.fieldExists(path) + ``` + +6. Returns all renames in the provided schema + + ```scala + structType.getRenamesInSchema(includeIfPredecessorChanged) + ``` +7. Get first array column's path out of complete path + + ```scala + structType.getFirstArrayPath(path) + ``` +8. Get all array columns' paths out of complete path. + + ```scala + structType.getAllArraysInPath(path) + ``` + +9. For a given list of field paths determines the deepest common array path + + ```scala + structType.getDeepestCommonArrayPath(fieldPaths) + ``` +10. For a field path determines the deepest array path + + ```scala + structType.getDeepestArrayPath(path) + ``` + +11. Get paths for all array fields in the schema + + ```scala + structType.getAllArrayPaths() + ``` + +12. Get a closest unique column name + + ```scala + structType.getClosestUniqueName(desiredName) + ``` + +13. Checks if a field is the only field in a struct + + ```scala + structType.isOnlyField(columnName) + ``` + +14. Checks if a field is an array that is not nested in another array + + ```scala + structType.isNonNestedArray(fieldPathName) + ``` + +15. 
Checks if a field is an array + + ```scala + structType.isArray(fieldPathName) + ``` # Spark Version Guard diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala index ad3306af..277c41d2 100644 --- a/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala +++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala @@ -17,6 +17,7 @@ package za.co.absa.spark.commons.implicits import org.apache.spark.sql.types._ +import za.co.absa.spark.commons.schema.MetadataKeys import scala.util.Try @@ -45,4 +46,19 @@ object StructFieldImplicits { metadata.contains(key) } } + + implicit class StructFieldEnhancement(val structField: StructField) { + /** + * Determine the name of a field + * + * @return Metadata "sourcecolumn" if it exists or field.name + */ + def getFieldNameOverriddenByMetadata(): String = { + if (structField.metadata.contains(MetadataKeys.SourceColumn)) { + structField.metadata.getString(MetadataKeys.SourceColumn) + } else { + structField.name + } + } + } } diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala index 55989782..acc53e18 100644 --- a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala +++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala @@ -17,7 +17,6 @@ package za.co.absa.spark.commons.implicits import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType} -import za.co.absa.spark.commons.implicits.StructFieldImplicits.StructFieldMetadataEnhancement import za.co.absa.spark.commons.schema.MetadataKeys import za.co.absa.spark.commons.schema.SchemaUtils.{appendPath, getAllArraySubPaths, isCommonSubPath} diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala index 8665be91..d39a28d0 100644 --- a/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala +++ b/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala @@ -16,7 +16,7 @@ package za.co.absa.spark.commons.implicits -import org.apache.spark.sql.types.{IntegerType, Metadata, MetadataBuilder, StringType, StructField} +import org.apache.spark.sql.types.{Metadata, StringType, StructField} import org.scalatest.funsuite.AnyFunSuite import za.co.absa.spark.commons.implicits.StructFieldImplicits.StructFieldMetadataEnhancement diff --git a/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilsSpec.scala b/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilsSpec.scala index f8d1446d..771e19ad 100644 --- a/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilsSpec.scala +++ b/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilsSpec.scala @@ -17,11 +17,9 @@ package za.co.absa.spark.commons.schema import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.types.{ArrayType, ByteType, DateType, DecimalType, IntegerType, LongType, MetadataBuilder, ShortType, StringType, StructField, StructType, TimestampType} import org.scalatest.BeforeAndAfterAll import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers -import za.co.absa.spark.commons.schema.SchemaUtils import za.co.absa.spark.commons.test.SparkTestBase class SchemaUtilsSpec extends AnyFlatSpec with Matchers with 
BeforeAndAfterAll with SparkTestBase { From 965908d56b0613209f6479a5ca851d4e264c5643 Mon Sep 17 00:00:00 2001 From: Adrian Olosutean Date: Fri, 28 Jan 2022 10:27:06 +0100 Subject: [PATCH 14/19] #22 bugfix --- .../commons/implicits/StructTypeImplicits.scala | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala index acc53e18..c2094571 100644 --- a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala +++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala @@ -293,7 +293,8 @@ object StructTypeImplicits { */ def isOnlyField(column: String): Boolean = { val path = column.split('.') - evaluateConditionsForField(schema, path, column, applyArrayHelper = false, field => field.fields.length == 1) + evaluateConditionsForField(schema, path, column, applyArrayHelper = false, applyLeafCondition = true, + field => field.fields.length == 1) } /** @@ -319,7 +320,7 @@ object StructTypeImplicits { } private def evaluateConditionsForField(structField: StructType, path: Seq[String], fieldPathName: String, - applyArrayHelper: Boolean, + applyArrayHelper: Boolean, applyLeafCondition: Boolean = false, conditionLeafSh: StructType => Boolean = _ => false): Boolean = { val currentField = path.head val isLeaf = path.lengthCompare(1) <= 0 @@ -328,7 +329,7 @@ object StructTypeImplicits { def arrayHelper(fieldPathName: String, arrayField: ArrayType, path: Seq[String]): Boolean = { arrayField.elementType match { case st: StructType => - evaluateConditionsForField(st, path.tail, fieldPathName, applyArrayHelper, conditionLeafSh) + evaluateConditionsForField(st, path.tail, fieldPathName, applyArrayHelper, applyLeafCondition, conditionLeafSh) case ar: ArrayType => arrayHelper(fieldPathName, ar, path) case _ => if (!isLeaf) { @@ -341,12 +342,10 @@ object StructTypeImplicits { structField.fields.exists(field => if (field.name == currentField) { - if (isLeaf && conditionLeafSh(structField)) { - true - } else - (field.dataType, isLeaf) match { + (field.dataType, isLeaf) match { case (st: StructType, false) => - evaluateConditionsForField(st, path.tail, fieldPathName, applyArrayHelper, conditionLeafSh) + evaluateConditionsForField(st, path.tail, fieldPathName, applyArrayHelper, applyLeafCondition, conditionLeafSh) + case (_, true) if applyLeafCondition => conditionLeafSh(structField) case (_: ArrayType, true) => true case (ar: ArrayType, false) if applyArrayHelper => arrayHelper(fieldPathName, ar, path) case (_: ArrayType, false) => false From 3230c90c93245548a8ed1b300de999a98da2c132 Mon Sep 17 00:00:00 2001 From: Adrian Olosutean Date: Thu, 3 Feb 2022 16:45:44 +0100 Subject: [PATCH 15/19] #22 some feedback --- README.md | 4 +- .../implicits/ArrayTypeImplicits.scala | 29 ++ .../implicits/DataFrameImplicits.scala | 20 ++ .../implicits/StructFieldImplicits.scala | 16 -- .../implicits/StructTypeImplicits.scala | 252 ++++++++--------- .../spark/commons/schema/MetadataKeys.scala | 22 -- .../spark/commons/schema/SchemaUtils.scala | 86 +----- .../implicits/DataFrameImplicitsSuite.scala | 26 +- .../StructTypeImplicitsArrayTest.scala | 189 +++++++++++++ .../implicits/StructTypeImplicitsTest.scala | 258 +----------------- .../commons/schema/SchemaUtilSuite.scala | 94 +++---- .../commons/schema/SchemaUtilsSpec.scala | 22 -- 12 files changed, 423 insertions(+), 595 deletions(-) create mode 100644 
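
The `applyLeafCondition` flag introduced in the bugfix above is what lets `isOnlyField` evaluate its predicate at the leaf; its observable behaviour, sketched on an illustrative schema:

```scala
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements

val schema = StructType(Seq(
  StructField("a", StringType),
  StructField("c", StructType(Seq(StructField("d", IntegerType))))))

schema.isOnlyField("a")    // false: "a" shares the top level with "c"
schema.isOnlyField("c.d")  // true: "d" is the only field inside struct "c"
```
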
src/main/scala/za/co/absa/spark/commons/implicits/ArrayTypeImplicits.scala
 delete mode 100644 src/main/scala/za/co/absa/spark/commons/schema/MetadataKeys.scala
 create mode 100644 src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsArrayTest.scala

diff --git a/README.md b/README.md
index 353692ec..cfbd3fb2 100644
--- a/README.md
+++ b/README.md
@@ -194,13 +194,13 @@ _StructTypeImplicits_ provides implicit methods for working with StructType obje
 14. Checks if a field is an array that is not nested in another array
 
    ```scala
-   structType.isNonNestedArray(fieldPathName)
+   structType.isNonNestedArray(path)
    ```
 
 15. Checks if a field is an array
 
    ```scala
-   structType.isArray(fieldPathName)
+   structType.isArray(path)
    ```
 
 # Spark Version Guard
diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/ArrayTypeImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/ArrayTypeImplicits.scala
new file mode 100644
index 00000000..e1dae212
--- /dev/null
+++ b/src/main/scala/za/co/absa/spark/commons/implicits/ArrayTypeImplicits.scala
@@ -0,0 +1,29 @@
+package za.co.absa.spark.commons.implicits
+
+import org.apache.spark.sql.types.{ArrayType, DataType}
+
+import scala.annotation.tailrec
+
+object ArrayTypeImplicits {
+
+  implicit class ArrayTypeEnhancements(arrayType: ArrayType) {
+
+    /**
+     * For an array of arrays of arrays, ... get the final element type at the bottom of the array
+     *
+     * @return A non-array data type at the bottom of array nesting
+     */
+    final def getDeepestArrayType(): DataType = {
+      @tailrec
+      def getDeepestArrayTypeHelper(arrayType: ArrayType): DataType = {
+        arrayType.elementType match {
+          case a: ArrayType => getDeepestArrayTypeHelper(a)
+          case b => b
+        }
+      }
+      getDeepestArrayTypeHelper(arrayType)
+    }
+
+
+  }
+}
diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala
index 4905b1d6..bc4357c3 100644
--- a/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala
+++ b/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala
@@ -18,7 +18,9 @@ package za.co.absa.spark.commons.implicits
 
 import java.io.ByteArrayOutputStream
 
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.{Column, DataFrame}
+import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements
 
 object DataFrameImplicits {
 
@@ -73,6 +75,24 @@ object DataFrameImplicits {
       df.withColumn(colName, colExpr)
     }
   }
+
+    /**
+     * Using schema selector returned from [[StructTypeEnhancements.getDataFrameSelector]] aligns the schema of a DataFrame to the selector
+     * for operations where schema order might be important (e.g. hashing the whole rows and using except)
+     *
+     * @param selector model selector (list of columns) for the alignment of df
+     * @return Returns aligned and filtered schema
+     */
+    def alignSchema(selector: List[Column]): DataFrame = df.select(selector: _*)
+
+    /**
+     * Using schema selector from [[getDataFrameSelector]] aligns the schema of a DataFrame to the selector for operations
+     * where schema order might be important (e.g.
hashing the whole rows and using except) + * + * @param structType model structType for the alignment of df + * @return Returns aligned and filtered schema + */ + def alignSchema(structType: StructType): DataFrame = alignSchema(structType.getDataFrameSelector()) } } diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala index 277c41d2..ad3306af 100644 --- a/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala +++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala @@ -17,7 +17,6 @@ package za.co.absa.spark.commons.implicits import org.apache.spark.sql.types._ -import za.co.absa.spark.commons.schema.MetadataKeys import scala.util.Try @@ -46,19 +45,4 @@ object StructFieldImplicits { metadata.contains(key) } } - - implicit class StructFieldEnhancement(val structField: StructField) { - /** - * Determine the name of a field - * - * @return Metadata "sourcecolumn" if it exists or field.name - */ - def getFieldNameOverriddenByMetadata(): String = { - if (structField.metadata.contains(MetadataKeys.SourceColumn)) { - structField.metadata.getString(MetadataKeys.SourceColumn) - } else { - structField.name - } - } - } } diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala index c2094571..bd8b2847 100644 --- a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala +++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala @@ -16,9 +16,10 @@ package za.co.absa.spark.commons.implicits +import org.apache.spark.sql.Column +import org.apache.spark.sql.functions.{col, struct} import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType} -import za.co.absa.spark.commons.schema.MetadataKeys -import za.co.absa.spark.commons.schema.SchemaUtils.{appendPath, getAllArraySubPaths, isCommonSubPath} +import za.co.absa.spark.commons.schema.SchemaUtils.{getAllArraySubPaths, isCommonSubPath, transform} import scala.annotation.tailrec import scala.util.Try @@ -114,54 +115,122 @@ object StructTypeImplicits { } /** - * Returns all renames in the provided schema. - * @param includeIfPredecessorChanged if set to true, fields are included even if their name have not changed but - * a predecessor's (parent, grandparent etc.) 
has - * @return the keys of the returned map are the columns' names after renames, the values are the source columns; - * the name are full paths denoted with dot notation + * Get paths for all array fields in the schema + * + * @return Seq of dot separated paths of fields in the schema, which are of type Array */ - def getRenamesInSchema(includeIfPredecessorChanged: Boolean = true): Map[String, String] = { - - def getRenamesRecursively(path: String, - sourcePath: String, - struct: StructType, - renamesAcc: Map[String, String], - predecessorChanged: Boolean): Map[String, String] = { - import za.co.absa.spark.commons.implicits.StructFieldImplicits.StructFieldMetadataEnhancement - - struct.fields.foldLeft(renamesAcc) { (renamesSoFar, field) => - val fieldFullName = appendPath(path, field.name) - val fieldSourceName = field.metadata.getOptString(MetadataKeys.SourceColumn).getOrElse(field.name) - val fieldFullSourceName = appendPath(sourcePath, fieldSourceName) - - val (renames, renameOnPath) = if ((fieldSourceName != field.name) || (predecessorChanged && includeIfPredecessorChanged)) { - (renamesSoFar + (fieldFullName -> fieldFullSourceName), true) - } else { - (renamesSoFar, predecessorChanged) - } + def getAllArrayPaths(): Seq[String] = { + schema.fields.flatMap(f => getAllArraySubPaths("", f.name, f.dataType)).toSeq + } + + /** + * Returns data selector that can be used to align schema of a data frame. You can use [[alignSchema]]. + * + * @return Sorted DF to conform to schema + */ + def getDataFrameSelector(): List[Column] = { + + def processArray(arrType: ArrayType, column: Column, name: String): Column = { + arrType.elementType match { + case arrType: ArrayType => + transform(column, x => processArray(arrType, x, name)).as(name) + case nestedStructType: StructType => + transform(column, x => struct(processStruct(nestedStructType, Some(x)): _*)).as(name) + case _ => column + } + } + def processStruct(curSchema: StructType, parent: Option[Column]): List[Column] = { + curSchema.foldRight(List.empty[Column])((field, acc) => { + val currentCol: Column = parent match { + case Some(x) => x.getField(field.name).as(field.name) + case None => col(field.name) + } field.dataType match { - case st: StructType => getRenamesRecursively(fieldFullName, fieldFullSourceName, st, renames, renameOnPath) - case at: ArrayType => getStructInArray(at.elementType).fold(renames) { item => - getRenamesRecursively(fieldFullName, fieldFullSourceName, item, renames, renameOnPath) - } - case _ => renames + case arrType: ArrayType => processArray(arrType, currentCol, field.name) :: acc + case structType: StructType => struct(processStruct(structType, Some(currentCol)): _*).as(field.name) :: acc + case _ => currentCol :: acc } - } + }) } + processStruct(schema, None) + } + + /** + * Get a closest unique column name + * + * @param desiredName A prefix to use for the column name + * @return A name that can be used as a unique column name + */ + def getClosestUniqueName(desiredName: String): String = { + def fieldExists(name: String): Boolean = schema.fields.exists(_.name.compareToIgnoreCase(name) == 0) + + if (fieldExists(desiredName)) { + Iterator.from(1) + .map(index => s"${desiredName}_${index}") + .dropWhile(fieldExists).next() + } else { + desiredName + } + } + + /** + * Checks if a field is the only field in a struct + * + * @param path A column to check + * @return true if the column is the only column in a struct + */ + def isOnlyField(path: String): Boolean = { + val pathSegments = path.split('.') + 
evaluateConditionsForField(schema, pathSegments, path, applyArrayHelper = false, applyLeafCondition = true,
+        field => field.fields.length == 1)
+    }
+
+    def isOfType[T <: DataType](path: String)(implicit classTag: scala.reflect.ClassTag[T]): Boolean = {
+      // compare against the runtime class of T; a plain isInstanceOf[T] check would be erased
+      val fieldType = getFieldType(path)
+      fieldType.exists(classTag.runtimeClass.isInstance(_))
+    }
+
+    protected def evaluateConditionsForField(structField: StructType, path: Seq[String], fieldPathName: String,
+                                             applyArrayHelper: Boolean, applyLeafCondition: Boolean = false,
+                                             conditionLeafSh: StructType => Boolean = _ => false): Boolean = {
+      val currentField = path.head
+      val isLeaf = path.lengthCompare(1) <= 0
+
+      @tailrec
+      def arrayHelper(fieldPathName: String, arrayField: ArrayType, path: Seq[String]): Boolean = {
+        arrayField.elementType match {
+          case st: StructType =>
+            evaluateConditionsForField(st, path.tail, fieldPathName, applyArrayHelper, applyLeafCondition, conditionLeafSh)
+          case ar: ArrayType => arrayHelper(fieldPathName, ar, path)
+          case _ =>
+            if (!isLeaf) {
+              throw new IllegalArgumentException(
+                s"Primitive fields cannot have child fields: $currentField is a primitive in $fieldPathName")
+            }
+            false
+        }
+      }
+
+      structField.fields.exists(field =>
+        if (field.name == currentField) {
+          (field.dataType, isLeaf) match {
+            case (st: StructType, false) =>
+              evaluateConditionsForField(st, path.tail, fieldPathName, applyArrayHelper, applyLeafCondition, conditionLeafSh)
+            case (_, true) if applyLeafCondition => conditionLeafSh(structField)
+            case (_: ArrayType, true) => true
+            case (ar: ArrayType, false) if applyArrayHelper => arrayHelper(fieldPathName, ar, path)
+            case (_: ArrayType, false) => false
+            case (_, false) => throw new IllegalArgumentException(s"Primitive fields cannot have child fields: $currentField is a primitive in $fieldPathName")
+            case (_, _) => false
+          }
+        } else false
+      )
+    }
   }
 
+  implicit class StructTypeEnhancementsArrays(schema: StructType) extends StructTypeEnhancements(schema) {
     /**
      * Get first array column's path out of complete path.
      *
@@ -227,11 +296,11 @@
      * The purpose of the function is to determine the order of explosions to be made before the dataframe can be
      * joined on a field inside an array.
      *
-     * @param fieldPaths A list of paths to analyze
+     * @param paths A list of paths to analyze
      * @return Returns a common array path if there is one and None if any of the arrays are on diverging paths
      */
-    def getDeepestCommonArrayPath(fieldPaths: Seq[String]): Option[String] = {
-      val arrayPaths = fieldPaths.flatMap(path => getAllArraysInPath(path)).distinct
+    def getDeepestCommonArrayPath(paths: Seq[String]): Option[String] = {
+      val arrayPaths = paths.flatMap(path => getAllArraysInPath(path)).distinct
 
       if (arrayPaths.nonEmpty && isCommonSubPath(arrayPaths: _*)) {
         Some(arrayPaths.maxBy(_.length))
@@ -245,11 +314,11 @@
      *
      * For instance, if given 'a.b.c.d' where b and c are arrays the deepest array is 'a.b.c'.
* - * @param fieldPath A path to analyze + * @param path A path to analyze * @return Returns a common array path if there is one and None if any of the arrays are on diverging paths */ - def getDeepestArrayPath(fieldPath: String): Option[String] = { - val arrayPaths = getAllArraysInPath(fieldPath) + def getDeepestArrayPath(path: String): Option[String] = { + val arrayPaths = getAllArraysInPath(path) if (arrayPaths.nonEmpty) { Some(arrayPaths.maxBy(_.length)) @@ -258,102 +327,15 @@ object StructTypeImplicits { } } - /** - * Get paths for all array fields in the schema - * - * @return Seq of dot separated paths of fields in the schema, which are of type Array - */ - def getAllArrayPaths(): Seq[String] = { - schema.fields.flatMap(f => getAllArraySubPaths("", f.name, f.dataType)).toSeq - } - - /** - * Get a closest unique column name - * - * @param desiredName A prefix to use for the column name - * @return A name that can be used as a unique column name - */ - def getClosestUniqueName(desiredName: String): String = { - def fieldExists(name: String): Boolean = schema.fields.exists(_.name.compareToIgnoreCase(name) == 0) - - if (fieldExists(desiredName)) { - Iterator.from(1) - .map(index => s"${desiredName}_${index}") - .dropWhile(fieldExists).next() - } else { - desiredName - } - } - - /** - * Checks if a field is the only field in a struct - * - * @param column A column to check - * @return true if the column is the only column in a struct - */ - def isOnlyField(column: String): Boolean = { - val path = column.split('.') - evaluateConditionsForField(schema, path, column, applyArrayHelper = false, applyLeafCondition = true, - field => field.fields.length == 1) - } - /** * Checks if a field is an array that is not nested in another array * - * @param fieldPathName A field to check + * @param path A field to check * @return true if a field is an array that is not nested in another array */ - def isNonNestedArray(fieldPathName: String): Boolean = { - val path = fieldPathName.split('.') - evaluateConditionsForField(schema, path, fieldPathName, applyArrayHelper = false) - } - - /** - * Checks if a field is an array - * - * @param fieldPathName A field to check - * @return true if the specified field is an array - */ - def isArray(fieldPathName: String): Boolean = { - val path = fieldPathName.split('.') - evaluateConditionsForField(schema, path, fieldPathName, applyArrayHelper = true) - } - - private def evaluateConditionsForField(structField: StructType, path: Seq[String], fieldPathName: String, - applyArrayHelper: Boolean, applyLeafCondition: Boolean = false, - conditionLeafSh: StructType => Boolean = _ => false): Boolean = { - val currentField = path.head - val isLeaf = path.lengthCompare(1) <= 0 - - @tailrec - def arrayHelper(fieldPathName: String, arrayField: ArrayType, path: Seq[String]): Boolean = { - arrayField.elementType match { - case st: StructType => - evaluateConditionsForField(st, path.tail, fieldPathName, applyArrayHelper, applyLeafCondition, conditionLeafSh) - case ar: ArrayType => arrayHelper(fieldPathName, ar, path) - case _ => - if (!isLeaf) { - throw new IllegalArgumentException( - s"Primitive fields cannot have child fields $currentField is a primitive in $fieldPathName") - } - false - } - } - - structField.fields.exists(field => - if (field.name == currentField) { - (field.dataType, isLeaf) match { - case (st: StructType, false) => - evaluateConditionsForField(st, path.tail, fieldPathName, applyArrayHelper, applyLeafCondition, conditionLeafSh) - case (_, true) if 
applyLeafCondition => conditionLeafSh(structField) - case (_: ArrayType, true) => true - case (ar: ArrayType, false) if applyArrayHelper => arrayHelper(fieldPathName, ar, path) - case (_: ArrayType, false) => false - case (_, false) => throw new IllegalArgumentException(s"Primitive fields cannot have child fields $currentField is a primitive in $fieldPathName") - case (_, _) => false - } - } else false - ) + def isNonNestedArray(path: String): Boolean = { + val pathSegments = path.split('.') + evaluateConditionsForField(schema, pathSegments, path, applyArrayHelper = false) } } diff --git a/src/main/scala/za/co/absa/spark/commons/schema/MetadataKeys.scala b/src/main/scala/za/co/absa/spark/commons/schema/MetadataKeys.scala deleted file mode 100644 index dc6a7aeb..00000000 --- a/src/main/scala/za/co/absa/spark/commons/schema/MetadataKeys.scala +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright 2021 ABSA Group Limited - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package za.co.absa.spark.commons.schema - -object MetadataKeys { - // all - val SourceColumn = "sourcecolumn" -} diff --git a/src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala b/src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala index 9096958c..badd3d44 100644 --- a/src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala +++ b/src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala @@ -16,12 +16,9 @@ package za.co.absa.spark.commons.schema -import org.apache.spark.sql.functions.{col, struct} import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Column, DataFrame} import za.co.absa.spark.commons.adapters.HofsAdapter -import scala.annotation.tailrec object SchemaUtils extends HofsAdapter { @@ -40,18 +37,6 @@ object SchemaUtils extends HofsAdapter { } } - /** - * Converts a fully qualified field name (including its path, e.g. containing fields) to a unique field name without - * dot notation - * - * @param path the fully qualified field name - * @return unique top level field name - */ - def unpath(path: String): String = { - path.replace("_", "__") - .replace('.', '_') - } - /** * Compares 2 array fields of a dataframe schema. * @@ -140,20 +125,6 @@ object SchemaUtils extends HofsAdapter { case _ => false } - /** - * For an array of arrays of arrays, ... get the final element type at the bottom of the array - * - * @param arrayType An array data type from a Spark dataframe schema - * @return A non-array data type at the bottom of array nesting - */ - @tailrec - final def getDeepestArrayType(arrayType: ArrayType): DataType = { - arrayType.elementType match { - case a: ArrayType => getDeepestArrayType(a) - case b => b - } - } - /** * Finds all differences of two StructFields and returns their paths * @@ -175,61 +146,6 @@ object SchemaUtils extends HofsAdapter { } } - /** - * Returns data selector that can be used to align schema of a data frame. You can use [[alignSchema]]. 
- * - * @param schema Schema that serves as the model of column order - * @return Sorted DF to conform to schema - */ - def getDataFrameSelector(schema: StructType): List[Column] = { - - def processArray(arrType: ArrayType, column: Column, name: String): Column = { - arrType.elementType match { - case arrType: ArrayType => - transform(column, x => processArray(arrType, x, name)).as(name) - case nestedStructType: StructType => - transform(column, x => struct(processStruct(nestedStructType, Some(x)): _*)).as(name) - case _ => column - } - } - - def processStruct(curSchema: StructType, parent: Option[Column]): List[Column] = { - curSchema.foldRight(List.empty[Column])((field, acc) => { - val currentCol: Column = parent match { - case Some(x) => x.getField(field.name).as(field.name) - case None => col(field.name) - } - field.dataType match { - case arrType: ArrayType => processArray(arrType, currentCol, field.name) :: acc - case structType: StructType => struct(processStruct(structType, Some(currentCol)): _*).as(field.name) :: acc - case _ => currentCol :: acc - } - }) - } - - processStruct(schema, None) - } - - /** - * Using schema selector from [[getDataFrameSelector]] aligns the schema of a DataFrame to the selector for operations - * where schema order might be important (e.g. hashing the whole rows and using except) - * - * @param df DataFrame to have it's schema aligned/sorted - * @param structType model structType for the alignment of df - * @return Returns aligned and filtered schema - */ - def alignSchema(df: DataFrame, structType: StructType): DataFrame = df.select(getDataFrameSelector(structType): _*) - - /** - * Using schema selector returned from [[getDataFrameSelector]] aligns the schema of a DataFrame to the selector - * for operations where schema order might be important (e.g. hashing the whole rows and using except) - * - * @param df DataFrame to have it's schema aligned/sorted - * @param selector model structType for the alignment of df - * @return Returns aligned and filtered schema - */ - def alignSchema(df: DataFrame, selector: List[Column]): DataFrame = df.select(selector: _*) - /** * Compares 2 dataframe schemas. 
* @@ -357,7 +273,7 @@ object SchemaUtils extends HofsAdapter { * @param targetType A type to be casted to * @return true if casting never fails */ - def isCastAlwaysSucceeds(sourceType: DataType, targetType: DataType): Boolean = { + def doesCastAlwaysSucceed(sourceType: DataType, targetType: DataType): Boolean = { (sourceType, targetType) match { case (_: StructType, _) | (_: ArrayType, _) => false case (a, b) if a == b => true diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala b/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala index ac5a5780..5eb595d6 100644 --- a/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala +++ b/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala @@ -16,7 +16,7 @@ package za.co.absa.spark.commons.implicits -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.functions.lit import org.scalatest.funsuite.AnyFunSuite import za.co.absa.spark.commons.test.SparkTestBase @@ -54,6 +54,11 @@ class DataFrameImplicitsSuite extends AnyFunSuite with SparkTestBase { "y", "z" ) + + val jsonA = """[{"id":1,"legs":[{"legid":100,"conditions":[{"checks":[{"checkNums":["1","2","3b","4","5c","6"]}],"amount":100}]}], "key" : {"alfa": "1", "beta": {"beta2": "2"}} }]""" + val jsonB = """[{"id":1,"legs":[{"legid":100,"conditions":[{"checks":[{"checkNums":["1","2","3b","4","5c","6"]}],"amount":100,"price":10}]}]}]""" + val jsonC = """[{"legs":[{"legid":100,"conditions":[{"amount":100,"checks":[{"checkNums":["1","2","3b","4","5c","6"]}]}]}],"id":1, "key" : {"beta": {"beta2": "2"}, "alfa": "1"} }]""" + private val inputData = inputDataSeq.toDF(columnName) import za.co.absa.spark.commons.implicits.DataFrameImplicits.DataFrameEnhancements @@ -234,4 +239,23 @@ class DataFrameImplicitsSuite extends AnyFunSuite with SparkTestBase { assert(dfIn.schema.head.name == "value") assert(actualOutput == expectedOutput) } + + test("order schemas for equal schemas") { + val dfA = spark.read.json(Seq(jsonA).toDS) + val dfC = spark.read.json(Seq(jsonC).toDS).select("legs", "id", "key") + + val dfA2Aligned = dfC.alignSchema(dfA.schema) + + assert(dfA.columns.toSeq == dfA2Aligned.columns.toSeq) + assert(dfA.select("key").columns.toSeq == dfA2Aligned.select("key").columns.toSeq) + } + + test("throw an error for DataFrames with different schemas") { + val dfA = spark.read.json(Seq(jsonA).toDS) + val dfB = spark.read.json(Seq(jsonB).toDS) + + assertThrows[AnalysisException]{ + dfA.alignSchema(dfB.schema) + } + } } diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsArrayTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsArrayTest.scala new file mode 100644 index 00000000..688ebc59 --- /dev/null +++ b/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsArrayTest.scala @@ -0,0 +1,189 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.spark.commons.implicits + +import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType} +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancementsArrays +import za.co.absa.spark.commons.test.SparkTestBase + +class StructTypeImplicitsArrayTest extends AnyFunSuite with SparkTestBase { + + private val schema = StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", StructType(Seq( + StructField("c", IntegerType), + StructField("d", StructType(Seq( + StructField("e", IntegerType))), nullable = true)))), + StructField("f", StructType(Seq( + StructField("g", ArrayType.apply(StructType(Seq( + StructField("h", IntegerType)))))))))) + + private val nestedSchema = StructType(Seq( + StructField("a", IntegerType), + StructField("b", ArrayType(StructType(Seq( + StructField("c", StructType(Seq( + StructField("d", ArrayType(StructType(Seq( + StructField("e", IntegerType)))))))))))))) + + test("Testing getFirstArrayPath") { + assertResult("f.g")(schema.getFirstArrayPath("f.g.h")) + assertResult("f.g")(schema.getFirstArrayPath("f.g")) + assertResult("")(schema.getFirstArrayPath("z.x.y")) + assertResult("")(schema.getFirstArrayPath("b.c.d.e")) + } + + test("Testing getAllArraysInPath") { + assertResult(Seq("b", "b.c.d"))(nestedSchema.getAllArraysInPath("b.c.d.e")) + } + + val sample = + """{"id":1,"legs":[{"legid":100,"conditions":[{"checks":[{"checkNums":["1","2","3b","4","5c","6"]}],"amount":100}]}]}""" :: + """{"id":2,"legs":[{"legid":200,"conditions":[{"checks":[{"checkNums":["8","9","10b","11","12c","13"]}],"amount":200}]}]}""" :: + """{"id":3,"legs":[{"legid":300,"conditions":[{"checks":[],"amount": 300}]}]}""" :: + """{"id":4,"legs":[{"legid":400,"conditions":[{"checks":null,"amount": 400}]}]}""" :: + """{"id":5,"legs":[{"legid":500,"conditions":[]}]}""" :: + """{"id":6,"legs":[]}""" :: + """{"id":7}""" :: Nil + import spark.implicits._ + val df = spark.read.json(sample.toDS) + + test("Test isNonNestedArray") { + assert(df.schema.isNonNestedArray("legs")) + assert(!df.schema.isNonNestedArray("legs.conditions")) + assert(!df.schema.isNonNestedArray("legs.conditions.checks")) + assert(!df.schema.isNonNestedArray("legs.conditions.checks.checkNums")) + assert(!df.schema.isNonNestedArray("id")) + assert(!df.schema.isNonNestedArray("legs.legid")) + } + + test("Test isNonArray") { + assert(df.schema.isOfType[ArrayType]("legs")) + assert(df.schema.isOfType[ArrayType]("legs.conditions")) + assert(df.schema.isOfType[ArrayType]("legs.conditions.checks")) + assert(df.schema.isOfType[ArrayType]("legs.conditions.checks.checkNums")) + assert(!df.schema.isOfType[ArrayType]("id")) + assert(!df.schema.isOfType[ArrayType]("legs.legid")) + } + + test("Test getDeepestCommonArrayPath() for a path without an array") { + val schema = StructType(Seq[StructField]( + StructField("a", + StructType(Seq[StructField]( + StructField("b", StringType)) + )))) + + assert(schema.getDeepestCommonArrayPath(Seq("a", "a.b")).isEmpty) + } + + test("Test getDeepestCommonArrayPath() for a path with a single array at top level") { + val schema = StructType(Seq[StructField]( + StructField("a", ArrayType(StructType(Seq[StructField]( + StructField("b", StringType))) + )))) + + val deepestPath = schema.getDeepestCommonArrayPath(Seq("a", "a.b")) + + assert(deepestPath.nonEmpty) + 
assert(deepestPath.get == "a") + } + + test("Test getDeepestCommonArrayPath() for a path with a single array at nested level") { + val schema = StructType(Seq[StructField]( + StructField("a", StructType(Seq[StructField]( + StructField("b", ArrayType(StringType)))) + ))) + + val deepestPath = schema.getDeepestCommonArrayPath(Seq("a", "a.b")) + + assert(deepestPath.nonEmpty) + assert(deepestPath.get == "a.b") + } + + test("Test getDeepestCommonArrayPath() for a path with several nested arrays of struct") { + val schema = StructType(Seq[StructField]( + StructField("a", ArrayType(StructType(Seq[StructField]( + StructField("b", StructType(Seq[StructField]( + StructField("c", ArrayType(StructType(Seq[StructField]( + StructField("d", StructType(Seq[StructField]( + StructField("e", StringType)) + ))) + )))) + ))) + ))))) + + val deepestPath = schema.getDeepestCommonArrayPath(Seq("a", "a.b", "a.b.c.d.e", "a.b.c.d")) + + assert(deepestPath.nonEmpty) + assert(deepestPath.get == "a.b.c") + } + + test("Test getDeepestArrayPath() for a path without an array") { + val schema = StructType(Seq[StructField]( + StructField("a", + StructType(Seq[StructField]( + StructField("b", StringType)) + )))) + + assert(schema.getDeepestArrayPath("a.b").isEmpty) + } + + test("Test getDeepestArrayPath() for a path with a single array at top level") { + val schema = StructType(Seq[StructField]( + StructField("a", ArrayType(StructType(Seq[StructField]( + StructField("b", StringType))) + )))) + + val deepestPath = schema.getDeepestArrayPath("a.b") + + assert(deepestPath.nonEmpty) + assert(deepestPath.get == "a") + } + + test("Test getDeepestArrayPath() for a path with a single array at nested level") { + val schema = StructType(Seq[StructField]( + StructField("a", StructType(Seq[StructField]( + StructField("b", ArrayType(StringType)))) + ))) + + val deepestPath = schema.getDeepestArrayPath("a.b") + val deepestPath2 = schema.getDeepestArrayPath("a") + + assert(deepestPath.nonEmpty) + assert(deepestPath.get == "a.b") + assert(deepestPath2.isEmpty) + } + + test("Test getDeepestArrayPath() for a path with several nested arrays of struct") { + val schema = StructType(Seq[StructField]( + StructField("a", ArrayType(StructType(Seq[StructField]( + StructField("b", StructType(Seq[StructField]( + StructField("c", ArrayType(StructType(Seq[StructField]( + StructField("d", StructType(Seq[StructField]( + StructField("e", StringType)) + ))) + )))) + ))) + ))))) + + val deepestPath = schema.getDeepestArrayPath("a.b.c.d.e") + + assert(deepestPath.nonEmpty) + assert(deepestPath.get == "a.b.c") + } + +} diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsTest.scala index cd017f41..58664a77 100644 --- a/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsTest.scala +++ b/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsTest.scala @@ -21,7 +21,7 @@ import org.scalatest.funsuite.AnyFunSuite import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements import za.co.absa.spark.commons.test.SparkTestBase -class StructTypeImplicitsTest extends AnyFunSuite with SparkTestBase{ +class StructTypeImplicitsTest extends AnyFunSuite with SparkTestBase { // scalastyle:off magic.number private val schema = StructType(Seq( @@ -55,6 +55,7 @@ class StructTypeImplicitsTest extends AnyFunSuite with SparkTestBase{ private val structFieldWithMetadataSourceColumn = StructField("a", IntegerType, 
nullable = false, new MetadataBuilder().putString("sourcecolumn", "override_a").build) test("Testing getFieldType") { + import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements val a = schema.getFieldType("a") val b = schema.getFieldType("b") @@ -106,122 +107,6 @@ class StructTypeImplicitsTest extends AnyFunSuite with SparkTestBase{ assert(nestedSchema.isColumnArrayOfStruct("b.c.d")) } - test("getRenamesInSchema - no renames") { - val schema = StructType(Seq( - structFieldNoMetadata, - structFieldWithMetadataNotSourceColumn)) - val result = schema.getRenamesInSchema() - assert(result.isEmpty) - } - - test("getRenamesInSchema - simple rename") { - val schema = StructType(Seq(structFieldWithMetadataSourceColumn)) - val result = schema.getRenamesInSchema() - assert(result == Map("a" -> "override_a")) - - } - - test("getRenamesInSchema - complex with includeIfPredecessorChanged set") { - val sub = StructType(Seq( - StructField("d", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "o").build), - StructField("e", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "e").build), - StructField("f", IntegerType) - )) - val schema = StructType(Seq( - StructField("a", sub, nullable = false, new MetadataBuilder().putString("sourcecolumn", "x").build), - StructField("b", sub, nullable = false, new MetadataBuilder().putString("sourcecolumn", "b").build), - StructField("c", sub) - )) - - val includeIfPredecessorChanged = true - val result = schema.getRenamesInSchema(includeIfPredecessorChanged) - val expected = Map( - "a" -> "x", - "a.d" -> "x.o", - "a.e" -> "x.e", - "a.f" -> "x.f", - "b.d" -> "b.o", - "c.d" -> "c.o" - ) - - assert(result == expected) - } - - test("getRenamesInSchema - complex with includeIfPredecessorChanged not set") { - val sub = StructType(Seq( - StructField("d", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "o").build), - StructField("e", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "e").build), - StructField("f", IntegerType) - )) - val schema = StructType(Seq( - StructField("a", sub, nullable = false, new MetadataBuilder().putString("sourcecolumn", "x").build), - StructField("b", sub, nullable = false, new MetadataBuilder().putString("sourcecolumn", "b").build), - StructField("c", sub) - )) - - val includeIfPredecessorChanged = false - val result = schema.getRenamesInSchema(includeIfPredecessorChanged) - val expected = Map( - "a" -> "x", - "a.d" -> "x.o", - "b.d" -> "b.o", - "c.d" -> "c.o" - ) - - assert(result == expected) - } - - - test("getRenamesInSchema - array") { - val sub = StructType(Seq( - StructField("renamed", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "rename source").build), - StructField("same", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "same").build), - StructField("f", IntegerType) - )) - val schema = StructType(Seq( - StructField("array1", ArrayType(sub)), - StructField("array2", ArrayType(ArrayType(ArrayType(sub)))), - StructField("array3", ArrayType(IntegerType), nullable = false, new MetadataBuilder().putString("sourcecolumn", "array source").build) - )) - - val includeIfPredecessorChanged = false - val result = schema.getRenamesInSchema(includeIfPredecessorChanged) - val expected = Map( - "array1.renamed" -> "array1.rename source", - "array2.renamed" -> "array2.rename source", - "array3" -> "array source" - ) - - assert(result == expected) 
- } - - - test("getRenamesInSchema - source column used multiple times") { - val sub = StructType(Seq( - StructField("x", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "src").build), - StructField("y", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "src").build) - )) - val schema = StructType(Seq( - StructField("a", sub), - StructField("b", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "src").build) - )) - - val result = schema.getRenamesInSchema() - val expected = Map( - "a.x" -> "a.src", - "a.y" -> "a.src", - "b" -> "src" - ) - - assert(result == expected) - } - - test("Testing getFirstArrayPath") { - assertResult("f.g")(schema.getFirstArrayPath("f.g.h")) - assertResult("f.g")(schema.getFirstArrayPath("f.g")) - assertResult("")(schema.getFirstArrayPath("z.x.y")) - assertResult("")(schema.getFirstArrayPath("b.c.d.e")) - } test("Testing getAllArrayPaths") { assertResult(Seq("f.g"))(schema.getAllArrayPaths()) @@ -229,121 +114,11 @@ class StructTypeImplicitsTest extends AnyFunSuite with SparkTestBase{ assertResult(Seq())(newSchema.getAllArrayPaths()) } - test("Testing getAllArraysInPath") { - assertResult(Seq("b", "b.c.d"))(nestedSchema.getAllArraysInPath("b.c.d.e")) - } - test("Testing getFieldNullability") { assert(schema.getFieldNullability("b.d").get) assert(schema.getFieldNullability("x.y.z").isEmpty) } - test("Test getDeepestCommonArrayPath() for a path without an array") { - val schema = StructType(Seq[StructField]( - StructField("a", - StructType(Seq[StructField]( - StructField("b", StringType)) - )))) - - assert(schema.getDeepestCommonArrayPath(Seq("a", "a.b")).isEmpty) - } - - test("Test getDeepestCommonArrayPath() for a path with a single array at top level") { - val schema = StructType(Seq[StructField]( - StructField("a", ArrayType(StructType(Seq[StructField]( - StructField("b", StringType))) - )))) - - val deepestPath = schema.getDeepestCommonArrayPath(Seq("a", "a.b")) - - assert(deepestPath.nonEmpty) - assert(deepestPath.get == "a") - } - - test("Test getDeepestCommonArrayPath() for a path with a single array at nested level") { - val schema = StructType(Seq[StructField]( - StructField("a", StructType(Seq[StructField]( - StructField("b", ArrayType(StringType)))) - ))) - - val deepestPath = schema.getDeepestCommonArrayPath(Seq("a", "a.b")) - - assert(deepestPath.nonEmpty) - assert(deepestPath.get == "a.b") - } - - test("Test getDeepestCommonArrayPath() for a path with several nested arrays of struct") { - val schema = StructType(Seq[StructField]( - StructField("a", ArrayType(StructType(Seq[StructField]( - StructField("b", StructType(Seq[StructField]( - StructField("c", ArrayType(StructType(Seq[StructField]( - StructField("d", StructType(Seq[StructField]( - StructField("e", StringType)) - ))) - )))) - ))) - ))))) - - val deepestPath = schema.getDeepestCommonArrayPath(Seq("a", "a.b", "a.b.c.d.e", "a.b.c.d")) - - assert(deepestPath.nonEmpty) - assert(deepestPath.get == "a.b.c") - } - - test("Test getDeepestArrayPath() for a path without an array") { - val schema = StructType(Seq[StructField]( - StructField("a", - StructType(Seq[StructField]( - StructField("b", StringType)) - )))) - - assert(schema.getDeepestArrayPath("a.b").isEmpty) - } - - test("Test getDeepestArrayPath() for a path with a single array at top level") { - val schema = StructType(Seq[StructField]( - StructField("a", ArrayType(StructType(Seq[StructField]( - StructField("b", StringType))) - )))) - - val deepestPath = 
schema.getDeepestArrayPath("a.b") - - assert(deepestPath.nonEmpty) - assert(deepestPath.get == "a") - } - - test("Test getDeepestArrayPath() for a path with a single array at nested level") { - val schema = StructType(Seq[StructField]( - StructField("a", StructType(Seq[StructField]( - StructField("b", ArrayType(StringType)))) - ))) - - val deepestPath = schema.getDeepestArrayPath("a.b") - val deepestPath2 = schema.getDeepestArrayPath("a") - - assert(deepestPath.nonEmpty) - assert(deepestPath.get == "a.b") - assert(deepestPath2.isEmpty) - } - - test("Test getDeepestArrayPath() for a path with several nested arrays of struct") { - val schema = StructType(Seq[StructField]( - StructField("a", ArrayType(StructType(Seq[StructField]( - StructField("b", StructType(Seq[StructField]( - StructField("c", ArrayType(StructType(Seq[StructField]( - StructField("d", StructType(Seq[StructField]( - StructField("e", StringType)) - ))) - )))) - ))) - ))))) - - val deepestPath = schema.getDeepestArrayPath("a.b.c.d.e") - - assert(deepestPath.nonEmpty) - assert(deepestPath.get == "a.b.c") - } - test("Test isOnlyField()") { val schema = StructType(Seq[StructField]( StructField("a", StringType), @@ -404,33 +179,4 @@ class StructTypeImplicitsTest extends AnyFunSuite with SparkTestBase{ assert(name2 == "value_3") } - val sample = - """{"id":1,"legs":[{"legid":100,"conditions":[{"checks":[{"checkNums":["1","2","3b","4","5c","6"]}],"amount":100}]}]}""" :: - """{"id":2,"legs":[{"legid":200,"conditions":[{"checks":[{"checkNums":["8","9","10b","11","12c","13"]}],"amount":200}]}]}""" :: - """{"id":3,"legs":[{"legid":300,"conditions":[{"checks":[],"amount": 300}]}]}""" :: - """{"id":4,"legs":[{"legid":400,"conditions":[{"checks":null,"amount": 400}]}]}""" :: - """{"id":5,"legs":[{"legid":500,"conditions":[]}]}""" :: - """{"id":6,"legs":[]}""" :: - """{"id":7}""" :: Nil - import spark.implicits._ - val df = spark.read.json(sample.toDS) - - test("Test isNonNestedArray") { - assert(df.schema.isNonNestedArray("legs")) - assert(!df.schema.isNonNestedArray("legs.conditions")) - assert(!df.schema.isNonNestedArray("legs.conditions.checks")) - assert(!df.schema.isNonNestedArray("legs.conditions.checks.checkNums")) - assert(!df.schema.isNonNestedArray("id")) - assert(!df.schema.isNonNestedArray("legs.legid")) - } - - test("Test isNonArray") { - assert(df.schema.isArray("legs")) - assert(df.schema.isArray("legs.conditions")) - assert(df.schema.isArray("legs.conditions.checks")) - assert(df.schema.isArray("legs.conditions.checks.checkNums")) - assert(!df.schema.isArray("id")) - assert(!df.schema.isArray("legs.legid")) - } - } diff --git a/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilSuite.scala b/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilSuite.scala index 23460362..6ddf76de 100644 --- a/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilSuite.scala +++ b/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilSuite.scala @@ -25,47 +25,47 @@ class SchemaUtilSuite extends AnyFunSuite with Matchers { // scalastyle:off magic.number - test ("Test isCastAlwaysSucceeds()") { - assert(!isCastAlwaysSucceeds(StructType(Seq()), StringType)) - assert(!isCastAlwaysSucceeds(ArrayType(StringType), StringType)) - assert(!isCastAlwaysSucceeds(StringType, ByteType)) - assert(!isCastAlwaysSucceeds(StringType, ShortType)) - assert(!isCastAlwaysSucceeds(StringType, IntegerType)) - assert(!isCastAlwaysSucceeds(StringType, LongType)) - assert(!isCastAlwaysSucceeds(StringType, DecimalType(10,10))) - 
assert(!isCastAlwaysSucceeds(StringType, DateType)) - assert(!isCastAlwaysSucceeds(StringType, TimestampType)) - assert(!isCastAlwaysSucceeds(StructType(Seq()), StructType(Seq()))) - assert(!isCastAlwaysSucceeds(ArrayType(StringType), ArrayType(StringType))) + test ("Test doesCastAlwaysSucceed()") { + assert(!doesCastAlwaysSucceed(StructType(Seq()), StringType)) + assert(!doesCastAlwaysSucceed(ArrayType(StringType), StringType)) + assert(!doesCastAlwaysSucceed(StringType, ByteType)) + assert(!doesCastAlwaysSucceed(StringType, ShortType)) + assert(!doesCastAlwaysSucceed(StringType, IntegerType)) + assert(!doesCastAlwaysSucceed(StringType, LongType)) + assert(!doesCastAlwaysSucceed(StringType, DecimalType(10,10))) + assert(!doesCastAlwaysSucceed(StringType, DateType)) + assert(!doesCastAlwaysSucceed(StringType, TimestampType)) + assert(!doesCastAlwaysSucceed(StructType(Seq()), StructType(Seq()))) + assert(!doesCastAlwaysSucceed(ArrayType(StringType), ArrayType(StringType))) - assert(!isCastAlwaysSucceeds(ShortType, ByteType)) - assert(!isCastAlwaysSucceeds(IntegerType, ByteType)) - assert(!isCastAlwaysSucceeds(IntegerType, ShortType)) - assert(!isCastAlwaysSucceeds(LongType, ByteType)) - assert(!isCastAlwaysSucceeds(LongType, ShortType)) - assert(!isCastAlwaysSucceeds(LongType, IntegerType)) + assert(!doesCastAlwaysSucceed(ShortType, ByteType)) + assert(!doesCastAlwaysSucceed(IntegerType, ByteType)) + assert(!doesCastAlwaysSucceed(IntegerType, ShortType)) + assert(!doesCastAlwaysSucceed(LongType, ByteType)) + assert(!doesCastAlwaysSucceed(LongType, ShortType)) + assert(!doesCastAlwaysSucceed(LongType, IntegerType)) - assert(isCastAlwaysSucceeds(StringType, StringType)) - assert(isCastAlwaysSucceeds(ByteType, StringType)) - assert(isCastAlwaysSucceeds(ShortType, StringType)) - assert(isCastAlwaysSucceeds(IntegerType, StringType)) - assert(isCastAlwaysSucceeds(LongType, StringType)) - assert(isCastAlwaysSucceeds(DecimalType(10,10), StringType)) - assert(isCastAlwaysSucceeds(DateType, StringType)) - assert(isCastAlwaysSucceeds(TimestampType, StringType)) - assert(isCastAlwaysSucceeds(StringType, StringType)) + assert(doesCastAlwaysSucceed(StringType, StringType)) + assert(doesCastAlwaysSucceed(ByteType, StringType)) + assert(doesCastAlwaysSucceed(ShortType, StringType)) + assert(doesCastAlwaysSucceed(IntegerType, StringType)) + assert(doesCastAlwaysSucceed(LongType, StringType)) + assert(doesCastAlwaysSucceed(DecimalType(10,10), StringType)) + assert(doesCastAlwaysSucceed(DateType, StringType)) + assert(doesCastAlwaysSucceed(TimestampType, StringType)) + assert(doesCastAlwaysSucceed(StringType, StringType)) - assert(isCastAlwaysSucceeds(ByteType, ByteType)) - assert(isCastAlwaysSucceeds(ByteType, ShortType)) - assert(isCastAlwaysSucceeds(ByteType, IntegerType)) - assert(isCastAlwaysSucceeds(ByteType, LongType)) - assert(isCastAlwaysSucceeds(ShortType, ShortType)) - assert(isCastAlwaysSucceeds(ShortType, IntegerType)) - assert(isCastAlwaysSucceeds(ShortType, LongType)) - assert(isCastAlwaysSucceeds(IntegerType, IntegerType)) - assert(isCastAlwaysSucceeds(IntegerType, LongType)) - assert(isCastAlwaysSucceeds(LongType, LongType)) - assert(isCastAlwaysSucceeds(DateType, TimestampType)) + assert(doesCastAlwaysSucceed(ByteType, ByteType)) + assert(doesCastAlwaysSucceed(ByteType, ShortType)) + assert(doesCastAlwaysSucceed(ByteType, IntegerType)) + assert(doesCastAlwaysSucceed(ByteType, LongType)) + assert(doesCastAlwaysSucceed(ShortType, ShortType)) + assert(doesCastAlwaysSucceed(ShortType, 
IntegerType)) + assert(doesCastAlwaysSucceed(ShortType, LongType)) + assert(doesCastAlwaysSucceed(IntegerType, IntegerType)) + assert(doesCastAlwaysSucceed(IntegerType, LongType)) + assert(doesCastAlwaysSucceed(LongType, LongType)) + assert(doesCastAlwaysSucceed(DateType, TimestampType)) } test("Test isCommonSubPath()") { @@ -74,22 +74,4 @@ class SchemaUtilSuite extends AnyFunSuite with Matchers { assert (isCommonSubPath("a.b.c.d.e.f", "a.b.c.d", "a.b.c", "a.b", "a")) assert (!isCommonSubPath("a.b.c.d.e.f", "a.b.c.x", "a.b.c", "a.b", "a")) } - - test("unpath - empty string remains empty") { - val result = unpath("") - val expected = "" - assert(result == expected) - } - - test("unpath - underscores get doubled") { - val result = unpath("one_two__three") - val expected = "one__two____three" - assert(result == expected) - } - - test("unpath - dot notation conversion") { - val result = unpath("grand_parent.parent.first_child") - val expected = "grand__parent_parent_first__child" - assert(result == expected) - } } diff --git a/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilsSpec.scala b/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilsSpec.scala index 771e19ad..f17763b5 100644 --- a/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilsSpec.scala +++ b/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilsSpec.scala @@ -16,7 +16,6 @@ package za.co.absa.spark.commons.schema -import org.apache.spark.sql.AnalysisException import org.scalatest.BeforeAndAfterAll import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers @@ -54,27 +53,6 @@ class SchemaUtilsSpec extends AnyFlatSpec with Matchers with BeforeAndAfterAll w SchemaUtils.equivalentSchemas(dfB.schema, dfA.schema) should be(false) } - behavior of "alignSchema" - - it should "order schemas for equal schemas" in { - val dfA = spark.read.json(Seq(jsonA).toDS) - val dfC = spark.read.json(Seq(jsonC).toDS).select("legs", "id", "key") - - val dfA2Aligned = SchemaUtils.alignSchema(dfC, dfA.schema) - - dfA.columns.toSeq should equal(dfA2Aligned.columns.toSeq) - dfA.select("key").columns.toSeq should equal(dfA2Aligned.select("key").columns.toSeq) - } - - it should "throw an error for DataFrames with different schemas" in { - val dfA = spark.read.json(Seq(jsonA).toDS) - val dfB = spark.read.json(Seq(jsonB).toDS) - - intercept[AnalysisException] { - SchemaUtils.alignSchema(dfA, dfB.schema) - } - } - behavior of "diffSchema" it should "produce a list of differences with path for schemas with different columns" in { From 09f3d205f7aee29c38f1e3526d96c5306df9445a Mon Sep 17 00:00:00 2001 From: Adrian Olosutean Date: Mon, 7 Feb 2022 15:58:46 +0100 Subject: [PATCH 16/19] #22 multiple changes to implicit classes --- README.md | 125 ++++++-- .../spark/commons/SparkVersionGuard.scala | 2 + .../implicits/ArrayTypeImplicits.scala | 65 +++- .../implicits/DataFrameImplicits.scala | 12 +- .../commons/implicits/DataTypeImplicits.scala | 89 ++++++ .../implicits/StructFieldImplicits.scala | 25 ++ .../implicits/StructTypeImplicits.scala | 92 +++++- .../spark/commons/schema/SchemaUtils.scala | 288 ------------------ .../spark/commons/utils/SchemaUtils.scala | 98 ++++++ .../implicits/ColumnImplicitsTest.scala | 2 +- ...ite.scala => DataFrameImplicitsTest.scala} | 6 +- .../implicits/DataTypeImplicitsTest.scala | 68 +++++ .../commons/implicits/JsonTestData.scala | 63 ++++ .../StructTypeImplicitsArrayTest.scala | 37 +-- .../implicits/StructTypeImplicitsTest.scala | 99 ++++-- .../commons/schema/SchemaUtilSuite.scala | 77 
-----
 .../commons/schema/SchemaUtilsSpec.scala      | 103 -------
 .../spark/commons/utils/SchemaUtilTest.scala  |  40 +++
 18 files changed, 703 insertions(+), 588 deletions(-)
 create mode 100644 src/main/scala/za/co/absa/spark/commons/implicits/DataTypeImplicits.scala
 delete mode 100644 src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala
 create mode 100644 src/main/scala/za/co/absa/spark/commons/utils/SchemaUtils.scala
 rename src/test/scala/za/co/absa/spark/commons/implicits/{DataFrameImplicitsSuite.scala => DataFrameImplicitsTest.scala} (92%)
 create mode 100644 src/test/scala/za/co/absa/spark/commons/implicits/DataTypeImplicitsTest.scala
 create mode 100644 src/test/scala/za/co/absa/spark/commons/implicits/JsonTestData.scala
 delete mode 100644 src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilSuite.scala
 delete mode 100644 src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilsSpec.scala
 create mode 100644 src/test/scala/za/co/absa/spark/commons/utils/SchemaUtilTest.scala

diff --git a/README.md b/README.md
index cfbd3fb2..cfd76041 100644
--- a/README.md
+++ b/README.md
@@ -30,13 +30,13 @@ _Spark Schema Utils_ provides methods for working with schemas, its comparison a
 1. Schema comparison returning true/false. Ignores the order of columns
 
    ```scala
-   SchemaUtils.equivalentSchemas(schema1, schema2)
+   SchemaUtils.equivalentSchemas(schema1, other)
    ```
 
 2. Schema comparison returning difference. Ignores the order of columns
 
    ```scala
-   SchemaUtils.diff(schema1, schema2)
+   SchemaUtils.diff(schema1, other)
    ```
 
 3. Schema selector generator which provides a List of columns to be used in a
@@ -67,11 +67,11 @@ _ColumnImplicits_ provide implicit methods for transforming Spark Columns
    ```scala
    column.zeroBasedSubstr(startPos)
    ```
+
 3. Returns column with requested substring. It shifts the substring indexation to be in accordance with Scala/ Java.
    If the provided starting position where to start the substring from is negative, it will be counted from end.
    The length of the desired substring, if longer then the rest of the string, all the remaining characters are taken.
 
-
    ```scala
    column.zeroBasedSubstr(startPos, length)
    ```
@@ -80,12 +80,6 @@
 
 _StructFieldImplicits_ provides implicit methods for working with StructField objects.
 
-1. Determine the name of a field overriden by metadata
-
-   ```scala
-   structField.getFieldNameOverriddenByMetadata()
-   ```
-
 Of them, metadata methods are:
 
 1. Gets the metadata Option[String] value given a key
@@ -113,6 +107,45 @@ Of them, metadata methods are:
    structField.metadata.hasKey(key)
    ```
 
+### ArrayTypeImplicits
+
+_ArrayTypeImplicits_ provides implicit methods for working with ArrayType objects.
+
+
+1. Checks whether two array types are equivalent, ignoring nullability
+
+   ```scala
+   arrayType.isEquivalentArrayType(otherArrayType)
+   ```
+
+2. For an array of arrays, get the final element type at the bottom of the array
+
+   ```scala
+   arrayType.getDeepestArrayType()
+   ```
+
+### DataTypeImplicits
+
+_DataTypeImplicits_ provides implicit methods for working with DataType objects.
+
+
+1. Checks whether two data types are equivalent, ignoring nullability
+
+   ```scala
+   dataType.isEquivalentDataType(otherDt)
+   ```
+
+2. Checks if a casting between types always succeeds
+
+   ```scala
+   dataType.doesCastAlwaysSucceed(otherDt)
+   ```
+3. Checks if type is primitive
+
+   ```scala
+   dataType.isPrimitive()
+   ```
+
 ### StructTypeImplicits
 
 _StructTypeImplicits_ provides implicit methods for working with StructType obje
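For illustration, a minimal sketch of what the new `ArrayTypeImplicits` and `DataTypeImplicits` checks evaluate to on concrete Spark types; the results shown are what the method descriptions above imply, not output copied from a test run:

```scala
import org.apache.spark.sql.types._
import za.co.absa.spark.commons.implicits.ArrayTypeImplicits.ArrayTypeEnhancements
import za.co.absa.spark.commons.implicits.DataTypeImplicits.DataTypeEnhancements

// equivalence ignores nullability but not structure
IntegerType.isEquivalentDataType(IntegerType)            // true
ArrayType(StringType).isEquivalentDataType(StringType)   // false

// widening numeric casts and casts to string never fail; parsing a string may
ByteType.doesCastAlwaysSucceed(LongType)                 // true
StringType.doesCastAlwaysSucceed(IntegerType)            // false

// nested arrays are unwrapped down to the non-array element type
ArrayType(ArrayType(IntegerType)).getDeepestArrayType()  // IntegerType
```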
@@ -145,62 +178,83 @@ _StructTypeImplicits_ provides implicit methods for working with StructType obje
    ```scala
    structType.fieldExists(path)
    ```
+
+6. Get paths for all array fields in the schema
 
-6. Returns all renames in the provided schema
+   ```scala
+   structType.getAllArrayPaths()
+   ```
+
+7. Get a closest unique column name
 
    ```scala
-   structType.getRenamesInSchema(includeIfPredecessorChanged)
+   structType.getClosestUniqueName(desiredName)
    ```
-7. Get first array column's path out of complete path
+
+8. Checks if a field is the only field in a struct
 
    ```scala
-   structType.getFirstArrayPath(path)
+   structType.isOnlyField(columnName)
    ```
-8. Get all array columns' paths out of complete path.
+9. Checks if 2 schemas are equivalent
 
    ```scala
-   structType.getAllArraysInPath(path)
+   structType.isEquivalent(other)
    ```
-9. For a given list of field paths determines the deepest common array path
+10. Returns a list of differences of one schema to the other
 
    ```scala
-   structType.getDeepestCommonArrayPath(fieldPaths)
+   structType.diffSchema(otherSchema, parent)
    ```
-10. For a field path determines the deepest array path
+
+11. Checks if a field is of the specified type
 
    ```scala
-   structType.getDeepestArrayPath(path)
+   structType.isOfType[ArrayType](path)
    ```
+12. Checks if the schema is a subset of the other schema
+
+   ```scala
+   structType.isSubset(other)
+   ```
 
-11. Get paths for all array fields in the schema
+13. Returns a data selector that can be used to align the schema of a DataFrame.
 
    ```scala
-   structType.getAllArrayPaths()
+   structType.getDataFrameSelector()
    ```
-12. Get a closest unique column name
+### StructTypeArrayImplicits
+
+1. Get first array column's path out of complete path
 
    ```scala
-   structType.getClosestUniqueName(desiredName)
+   structType.getFirstArrayPath(path)
    ```
+
+2. Get all array columns' paths out of complete path.
 
-13. Checks if a field is the only field in a struct
+   ```scala
+   structType.getAllArraysInPath(path)
+   ```
+
+3. For a given list of field paths determines the deepest common array path
 
    ```scala
-   structType.isOnlyField(columnName)
+   structType.getDeepestCommonArrayPath(fieldPaths)
    ```
-14. Checks if a field is an array that is not nested in another array
+4. For a field path determines the deepest array path
 
    ```scala
-   structType.isNonNestedArray(path)
+   structType.getDeepestArrayPath(path)
    ```
-
-15. Checks if a field is an array
+
+5. Checks if a field is an array that is not nested in another array (see the usage sketch below)
 
    ```scala
-   structType.isArray(path)
+   structType.isNonNestedArray(path)
    ```
 
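To make the array-path helpers above concrete, here is a small sketch; the `legs`/`conditions` schema is hypothetical and the commented results follow from the method descriptions:

```scala
import org.apache.spark.sql.types._
import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancementsArrays

// an array of structs ("legs") containing a nested array ("conditions")
val schema = StructType(Seq(
  StructField("legs", ArrayType(StructType(Seq(
    StructField("conditions", ArrayType(StructType(Seq(
      StructField("amount", IntegerType)
    ))))
  ))))
))

schema.getFirstArrayPath("legs.conditions.amount")   // "legs"
schema.getAllArraysInPath("legs.conditions.amount")  // Seq("legs", "legs.conditions")
schema.getDeepestArrayPath("legs.conditions.amount") // Some("legs.conditions")
schema.isNonNestedArray("legs")                      // true
schema.isNonNestedArray("legs.conditions")           // false: nested in another array
```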
 # Spark Version Guard
@@ -243,4 +297,15 @@ _DataFrameImplicits_ provides methods for transformations on Dataframes
 
    ```scala
    df.withColumnIfDoesNotExist((df: DataFrame, _) => df)(colName, colExpression)
+   ```
+
+3. Aligns the schema of a DataFrame to the selector for operations
+   where schema order might be important (e.g. hashing the whole rows and using except)
+
+   ```scala
+   df.alignSchema(structType)
+   ```
+
+   ```scala
+   df.alignSchema(listColumns)
+   ```
\ No newline at end of file
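The two `alignSchema` variants documented above can be exercised as follows; `dfA` and `dfC` are hypothetical DataFrames with the same fields in a different order, mirroring the test added in this patch:

```scala
import za.co.absa.spark.commons.implicits.DataFrameImplicits.DataFrameEnhancements
import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements

// align dfC's column order (including nested structs) to dfA's schema
val aligned = dfC.alignSchema(dfA.schema)

// or compute the selector once and reuse it across many DataFrames
val selector = dfA.schema.getDataFrameSelector()
val alignedAgain = dfC.alignSchema(selector)
```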
This is used for the accumulation of differences and their print out
+     * @return Returns a Seq of found difference paths in the schema in the Array
+     */
+    @scala.annotation.tailrec
+    private[implicits] final def diffArray(array2: ArrayType, parent: String): Seq[String] = {
+      arrayType.elementType match {
+        case _ if arrayType.elementType.typeName != array2.elementType.typeName =>
+          Seq(s"$parent data type doesn't match (${arrayType.elementType.typeName}) vs (${array2.elementType.typeName})")
+        case arrayType1: ArrayType =>
+          arrayType1.diffArray(array2.elementType.asInstanceOf[ArrayType], s"$parent")
+        case structType1: StructType =>
+          structType1.diffSchema(array2.elementType.asInstanceOf[StructType], s"$parent")
+        case _ => Seq.empty[String]
+      }
+    }
+
     /**
      * For an array of arrays of arrays, ... get the final element type at the bottom of the array
      *
@@ -23,7 +84,5 @@ object ArrayTypeImplicits {
       }
       getDeepestArrayTypeHelper(arrayType)
     }
-
-
   }
 }
diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala
index bc4357c3..723a2e51 100644
--- a/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala
+++ b/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala
@@ -77,20 +77,20 @@ object DataFrameImplicits {
     }
 
     /**
-     * Using schema selector returned from [[StructTypeEnhancements.getDataFrameSelector]] aligns the schema of a DataFrame to the selector
-     * for operations where schema order might be important (e.g. hashing the whole rows and using except)
+     * Using utils selector returned from [[StructTypeEnhancements.getDataFrameSelector]] aligns the utils of a DataFrame to the selector
+     * for operations where utils order might be important (e.g. hashing the whole rows and using except)
      *
      * @param selector model structType for the alignment of df
-     * @return Returns aligned and filtered schema
+     * @return Returns aligned and filtered utils
      */
     def alignSchema(selector: List[Column]): DataFrame = df.select(selector: _*)
 
     /**
-     * Using schema selector from [[getDataFrameSelector]] aligns the schema of a DataFrame to the selector for operations
-     * where schema order might be important (e.g. hashing the whole rows and using except)
+     * Using utils selector from [[getDataFrameSelector]] aligns the utils of a DataFrame to the selector for operations
+     * where utils order might be important (e.g. hashing the whole rows and using except)
      *
      * @param structType model structType for the alignment of df
-     * @return Returns aligned and filtered schema
+     * @return Returns aligned and filtered utils
      */
     def alignSchema(structType: StructType): DataFrame = alignSchema(structType.getDataFrameSelector())
   }
diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/DataTypeImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/DataTypeImplicits.scala
new file mode 100644
index 00000000..b3498005
--- /dev/null
+++ b/src/main/scala/za/co/absa/spark/commons/implicits/DataTypeImplicits.scala
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2021 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.spark.commons.implicits
+
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, NullType, ShortType, StringType, StructType, TimestampType}
+import za.co.absa.spark.commons.implicits.ArrayTypeImplicits.ArrayTypeEnhancements
+import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements
+
+object DataTypeImplicits {
+
+  implicit class DataTypeEnhancements(dt: DataType) {
+    /**
+     * Compares 2 fields of a dataframe schema.
+     *
+     * @param other The second field to compare
+     * @return true if provided fields are the same ignoring nullability
+     */
+    def isEquivalentDataType(other: DataType): Boolean = {
+      dt match {
+        case arrayType1: ArrayType =>
+          other match {
+            case arrayType2: ArrayType => arrayType1.isEquivalentArrayType(arrayType2)
+            case _ => false
+          }
+        case structType1: StructType =>
+          other match {
+            case structType2: StructType => structType1.isEquivalent(structType2)
+            case _ => false
+          }
+        case _ => dt == other
+      }
+    }
+
+    /**
+     * Checks if a casting between types always succeeds
+     *
+     * @param targetType A type to be cast to
+     * @return true if casting never fails
+     */
+    def doesCastAlwaysSucceed(targetType: DataType): Boolean = {
+      (dt, targetType) match {
+        case (_: StructType, _) | (_: ArrayType, _) => false
+        case (a, b) if a == b => true
+        case (_, _: StringType) => true
+        case (_: ByteType, _: ShortType | _: IntegerType | _: LongType) => true
+        case (_: ShortType, _: IntegerType | _: LongType) => true
+        case (_: IntegerType, _: LongType) => true
+        case (_: DateType, _: TimestampType) => true
+        case _ => false
+      }
+    }
+
+    /**
+     * Determine if a datatype is a primitive one
+     */
+    def isPrimitive(): Boolean = dt match {
+      case _: BinaryType
+           | _: BooleanType
+           | _: ByteType
+           | _: DateType
+           | _: DecimalType
+           | _: DoubleType
+           | _: FloatType
+           | _: IntegerType
+           | _: LongType
+           | _: NullType
+           | _: ShortType
+           | _: StringType
+           | _: TimestampType => true
+      case _ => false
+    }
+  }
+
+
+}
diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala
index ad3306af..24d2fdf5 100644
--- a/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala
+++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala
@@ -17,6 +17,8 @@ package za.co.absa.spark.commons.implicits
 
 import org.apache.spark.sql.types._
+import za.co.absa.spark.commons.implicits.ArrayTypeImplicits.ArrayTypeEnhancements
+import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements
 
 import scala.util.Try
 
@@ -45,4 +47,27 @@ object StructFieldImplicits {
       metadata.contains(key)
     }
   }
+
+  implicit class StructFieldEnhamcenets(val structField: StructField) {
+
+    /**
+     * Finds all differences of two StructFields and returns their paths
+     *
+     * @param other The second field to compare
+     * @param parent Parent path.
This is used for the accumulation of differences and their print out
+     * @return Returns a Seq of found difference paths in the schema in the StructField
+     */
+    private[implicits] def diffField(other: StructField, parent: String): Seq[String] = {
+      structField.dataType match {
+        case _ if structField.dataType.typeName != other.dataType.typeName =>
+          Seq(s"$parent.${structField.name} data type doesn't match (${structField.dataType.typeName}) vs (${other.dataType.typeName})")
+        case arrayType1: ArrayType =>
+          arrayType1.diffArray(other.dataType.asInstanceOf[ArrayType], s"$parent.${structField.name}")
+        case structType1: StructType =>
+          structType1.diffSchema(other.dataType.asInstanceOf[StructType], s"$parent.${structField.name}")
+        case _ =>
+          Seq.empty[String]
+      }
+    }
+  }
 }
diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala
index bd8b2847..a14142e5 100644
--- a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala
+++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala
@@ -18,16 +18,19 @@ package za.co.absa.spark.commons.implicits
 
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.functions.{col, struct}
-import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType}
-import za.co.absa.spark.commons.schema.SchemaUtils.{getAllArraySubPaths, isCommonSubPath, transform}
+import org.apache.spark.sql.types.{ArrayType, DataType, NullType, StructField, StructType}
+import za.co.absa.spark.commons.implicits.DataTypeImplicits.DataTypeEnhancements
+import za.co.absa.spark.commons.implicits.StructFieldImplicits.StructFieldEnhamcenets
+import za.co.absa.spark.commons.utils.SchemaUtils.{getAllArraySubPaths, isCommonSubPath, transform}
 
 import scala.annotation.tailrec
+import scala.reflect.runtime.universe._
 import scala.util.Try
 
 object StructTypeImplicits {
   implicit class StructTypeEnhancements(val schema: StructType) {
     /**
-     * Get a field from a text path and a given schema
+     * Get a field from a text path and a given utils
      * @param path The dot-separated path to the field
      * @return Some(the requested field) or None if the field does not exist
      */
@@ -65,7 +68,7 @@
     }
 
     /**
-     * Get a type of a field from a text path and a given schema
+     * Get a type of a field from a text path and a given utils
      *
      * @param path The dot-separated path to the field
      * @return Some(the type of the field) or None if the field does not exist
@@ -96,7 +99,7 @@
     }
 
     /**
-     * Get nullability of a field from a text path and a given schema
+     * Get nullability of a field from a text path and a given utils
      *
      * @param path The dot-separated path to the field
      * @return Some(nullable) or None if the field does not exist
@@ -106,7 +109,7 @@
     }
 
     /**
-     * Checks if a field specified by a path and a schema exists
+     * Checks if a field specified by a path and a utils exists
      * @param path The dot-separated path to the field
      * @return True if the field exists false otherwise
      */
@@ -115,18 +118,18 @@
     }
 
     /**
-     * Get paths for all array fields in the schema
+     * Get paths for all array fields in the utils
      *
-     * @return Seq of dot separated paths of fields in the schema, which are of type Array
+     * @return Seq of dot separated paths of fields in the utils, which are of type Array
      */
     def getAllArrayPaths(): Seq[String] = {
       schema.fields.flatMap(f =>
getAllArraySubPaths("", f.name, f.dataType)).toSeq
     }
 
     /**
-     * Returns data selector that can be used to align schema of a data frame. You can use [[alignSchema]].
+     * Returns data selector that can be used to align utils of a data frame. You can use [[alignSchema]].
      *
-     * @return Sorted DF to conform to schema
+     * @return Sorted DF to conform to utils
      */
 
     def getDataFrameSelector(): List[Column] = {
@@ -187,9 +190,72 @@ object StructTypeImplicits {
         field => field.fields.length == 1)
     }
 
-    def isOfType[T <: DataType](path: String) = {
-      val fieldType = getFieldType(path).getOrElse(false)
-      fieldType.isInstanceOf[T]
+    /**
+     * Compares 2 dataframe schemas.
+     *
+     * @param other The second schema to compare
+     * @return true if provided schemas are the same ignoring nullability
+     */
+    def isEquivalent(other: StructType): Boolean = {
+      val currentfields = schema.sortBy(_.name.toLowerCase)
+      val fields2 = other.sortBy(_.name.toLowerCase)
+
+      currentfields.size == fields2.size &&
+        currentfields.zip(fields2).forall {
+          case (f1, f2) =>
+            f1.name.equalsIgnoreCase(f2.name) &&
+              f1.dataType.isEquivalentDataType(f2.dataType)
+        }
+    }
+
+    /**
+     * Returns a list of differences in one schema to the other
+     *
+     * @param other The second schema to compare
+     * @param parent Parent path. Should be left default by the user on the first run. This is used for the accumulation of
+     *               differences and their print out.
+     * @return Returns a Seq of paths to differences in schemas
+     */
+    def diffSchema(other: StructType, parent: String = ""): Seq[String] = {
+      val fields1 = getMapOfFields(schema)
+      val fields2 = getMapOfFields(other)
+
+      val diff = fields1.values.foldLeft(Seq.empty[String])((difference, field1) => {
+        val field1NameLc = field1.name.toLowerCase()
+        if (fields2.contains(field1NameLc)) {
+          val field2 = fields2(field1NameLc)
+          difference ++ field1.diffField(field2, parent)
+        } else {
+          difference ++ Seq(s"$parent.${field1.name} cannot be found in both schemas")
+        }
+      })
+
+      diff.map(_.stripPrefix("."))
+    }
+
+    /**
+     * Checks if this schema is a subset of the given originalSchema.
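+     *
+     * A minimal sketch of the expected behaviour, with hypothetical fields used purely for illustration:
+     * {{{
+     *   val subset = StructType(Seq(StructField("id", LongType)))
+     *   val full = StructType(Seq(StructField("id", LongType), StructField("name", StringType)))
+     *   subset.isSubset(full)   // true: every field of subset exists in full with an equivalent type
+     *   full.isSubset(subset)   // false: "name" is missing from subset
+     * }}}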
+     *
+     * @param originalSchema The schema that needs to contain at least all the fields of this schema
+     * @return true if this schema is a subset of originalSchema, ignoring nullability
+     */
+    def isSubset(originalSchema: StructType): Boolean = {
+      val subsetFields = getMapOfFields(schema)
+      val originalFields = getMapOfFields(originalSchema)
+
+      subsetFields.forall(subsetField =>
+        originalFields.contains(subsetField._1) &&
+          subsetField._2.dataType.isEquivalentDataType(originalFields(subsetField._1).dataType))
+    }
+
+    private def getMapOfFields(schema: StructType): Map[String, StructField] = {
+      schema.map(field => field.name.toLowerCase() -> field).toMap
+    }
+
+    def isOfType[T <: DataType](path: String)(implicit tag: TypeTag[T]): Boolean = {
+      val fieldType = getFieldType(path).getOrElse(NullType)
+      val classT: Class[_] = runtimeMirror(getClass.getClassLoader).runtimeClass(tag.tpe.typeSymbol.asClass)
+      classT.equals(fieldType.getClass)
     }
 
     protected def evaluateConditionsForField(structField: StructType, path: Seq[String], fieldPathName: String,
diff --git a/src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala b/src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala
deleted file mode 100644
index badd3d44..00000000
--- a/src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala
+++ /dev/null
@@ -1,288 +0,0 @@
-/*
- * Copyright 2021 ABSA Group Limited
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package za.co.absa.spark.commons.schema
-
-import org.apache.spark.sql.types._
-import za.co.absa.spark.commons.adapters.HofsAdapter
-
-
-object SchemaUtils extends HofsAdapter {
-
-  /**
-   * Returns the parent path of a field. Returns an empty string if a root level field name is provided.
-   *
-   * @param columnName A fully qualified column name
-   * @return The parent column name or an empty string if the input column is a root level column
-   */
-  def getParentPath(columnName: String): String = {
-    val index = columnName.lastIndexOf('.')
-    if (index > 0) {
-      columnName.substring(0, index)
-    } else {
-      ""
-    }
-  }
-
-  /**
-   * Compares 2 array fields of a dataframe schema.
-   *
-   * @param array1 The first array to compare
-   * @param array2 The second array to compare
-   * @return true if provided arrays are the same ignoring nullability
-   */
-  @scala.annotation.tailrec
-  private def equalArrayTypes(array1: ArrayType, array2: ArrayType): Boolean = {
-    array1.elementType match {
-      case arrayType1: ArrayType =>
-        array2.elementType match {
-          case arrayType2: ArrayType => equalArrayTypes(arrayType1, arrayType2)
-          case _ => false
-        }
-      case structType1: StructType =>
-        array2.elementType match {
-          case structType2: StructType => equivalentSchemas(structType1, structType2)
-          case _ => false
-        }
-      case _ => array1.elementType == array2.elementType
-    }
-  }
-
-  /**
-   * Finds all differences of two ArrayTypes and returns their paths
-   *
-   * @param array1 The first array to compare
-   * @param array2 The second array to compare
-   * @param parent Parent path.
This is used for the accumulation of differences and their print out - * @return Returns a Seq of found difference paths in scheme in the Array - */ - @scala.annotation.tailrec - private def diffArray(array1: ArrayType, array2: ArrayType, parent: String): Seq[String] = { - array1.elementType match { - case _ if array1.elementType.typeName != array2.elementType.typeName => - Seq(s"$parent data type doesn't match (${array1.elementType.typeName}) vs (${array2.elementType.typeName})") - case arrayType1: ArrayType => - diffArray(arrayType1, array2.elementType.asInstanceOf[ArrayType], s"$parent") - case structType1: StructType => - diffSchema(structType1, array2.elementType.asInstanceOf[StructType], s"$parent") - case _ => Seq.empty[String] - } - } - - /** - * Compares 2 fields of a dataframe schema. - * - * @param type1 The first field to compare - * @param type2 The second field to compare - * @return true if provided fields are the same ignoring nullability - */ - private def equivalentTypes(type1: DataType, type2: DataType): Boolean = { - type1 match { - case arrayType1: ArrayType => - type2 match { - case arrayType2: ArrayType => equalArrayTypes(arrayType1, arrayType2) - case _ => false - } - case structType1: StructType => - type2 match { - case structType2: StructType => equivalentSchemas(structType1, structType2) - case _ => false - } - case _ => type1 == type2 - } - } - - /** - * Determine if a datatype is a primitive one - */ - def isPrimitive(dt: DataType): Boolean = dt match { - case _: BinaryType - | _: BooleanType - | _: ByteType - | _: DateType - | _: DecimalType - | _: DoubleType - | _: FloatType - | _: IntegerType - | _: LongType - | _: NullType - | _: ShortType - | _: StringType - | _: TimestampType => true - case _ => false - } - - /** - * Finds all differences of two StructFields and returns their paths - * - * @param field1 The first field to compare - * @param field2 The second field to compare - * @param parent Parent path. This is used for the accumulation of differences and their print out - * @return Returns a Seq of found difference paths in scheme in the StructField - */ - private def diffField(field1: StructField, field2: StructField, parent: String): Seq[String] = { - field1.dataType match { - case _ if field1.dataType.typeName != field2.dataType.typeName => - Seq(s"$parent.${field1.name} data type doesn't match (${field1.dataType.typeName}) vs (${field2.dataType.typeName})") - case arrayType1: ArrayType => - diffArray(arrayType1, field2.dataType.asInstanceOf[ArrayType], s"$parent.${field1.name}") - case structType1: StructType => - diffSchema(structType1, field2.dataType.asInstanceOf[StructType], s"$parent.${field1.name}") - case _ => - Seq.empty[String] - } - } - - /** - * Compares 2 dataframe schemas. - * - * @param schema1 The first schema to compare - * @param schema2 The second schema to compare - * @return true if provided schemas are the same ignoring nullability - */ - def equivalentSchemas(schema1: StructType, schema2: StructType): Boolean = { - val fields1 = schema1.sortBy(_.name.toLowerCase) - val fields2 = schema2.sortBy(_.name.toLowerCase) - - fields1.size == fields2.size && - fields1.zip(fields2).forall { - case (f1, f2) => - f1.name.equalsIgnoreCase(f2.name) && - equivalentTypes(f1.dataType, f2.dataType) - } - } - - /** - * Returns a list of differences in one schema to the other - * - * @param schema1 The first schema to compare - * @param schema2 The second schema to compare - * @param parent Parent path. 
Should be left default by the users first run. This is used for the accumulation of - * differences and their print out. - * @return Returns a Seq of paths to differences in schemas - */ - def diffSchema(schema1: StructType, schema2: StructType, parent: String = ""): Seq[String] = { - val fields1 = getMapOfFields(schema1) - val fields2 = getMapOfFields(schema2) - - val diff = fields1.values.foldLeft(Seq.empty[String])((difference, field1) => { - val field1NameLc = field1.name.toLowerCase() - if (fields2.contains(field1NameLc)) { - val field2 = fields2(field1NameLc) - difference ++ diffField(field1, field2, parent) - } else { - difference ++ Seq(s"$parent.${field1.name} cannot be found in both schemas") - } - }) - - diff.map(_.stripPrefix(".")) - } - - /** - * Checks if the originalSchema is a subset of subsetSchema. - * - * @param subsetSchema The schema that needs to be extracted - * @param originalSchema The schema that needs to have at least all t - * @return true if provided schemas are the same ignoring nullability - */ - def isSubset(subsetSchema: StructType, originalSchema: StructType): Boolean = { - val subsetFields = getMapOfFields(subsetSchema) - val originalFields = getMapOfFields(originalSchema) - - subsetFields.forall(subsetField => - originalFields.contains(subsetField._1) && - equivalentTypes(subsetField._2.dataType, originalFields(subsetField._1).dataType)) - } - - private def getMapOfFields(schema: StructType): Map[String, StructField] = { - schema.map(field => field.name.toLowerCase() -> field).toMap - } - - /** - * Get paths for all array subfields of this given datatype - */ - def getAllArraySubPaths(path: String, name: String, dt: DataType): Seq[String] = { - val currPath = appendPath(path, name) - dt match { - case s: StructType => s.fields.flatMap(f => getAllArraySubPaths(currPath, f.name, f.dataType)) - case _@ArrayType(elType, _) => getAllArraySubPaths(path, name, elType) :+ currPath - case _ => Seq() - } - } - - /** - * For a given list of field paths determines if any path pair is a subset of one another. - * - * For instance, - * - 'a.b', 'a.b.c', 'a.b.c.d' have this property. - * - 'a.b', 'a.b.c', 'a.x.y' does NOT have it, since 'a.b.c' and 'a.x.y' have diverging subpaths. - * - * @param paths A list of paths to be analyzed - * @return true if for all pathe the above property holds - */ - def isCommonSubPath(paths: String*): Boolean = { - def sliceRoot(paths: Seq[Seq[String]]): Seq[Seq[String]] = { - paths.map(path => path.drop(1)).filter(_.nonEmpty) - } - - var isParentCommon = true // For Seq() the property holds by [my] convention - var restOfPaths: Seq[Seq[String]] = paths.map(_.split('.').toSeq).filter(_.nonEmpty) - while (isParentCommon && restOfPaths.nonEmpty) { - val parent = restOfPaths.head.head - isParentCommon = restOfPaths.forall(path => path.head == parent) - restOfPaths = sliceRoot(restOfPaths) - } - isParentCommon - } - - /** - * Append a new attribute to path or empty string. 
- * - * @param path The dot-separated existing path - * @param fieldName Name of the field to be appended to the path - * @return The path with the new field appended or the field itself if path is empty - */ - def appendPath(path: String, fieldName: String): String = { - if (path.isEmpty) { - fieldName - } else if (fieldName.isEmpty) { - path - } else { - s"$path.$fieldName" - } - } - - - /** - * Checks if a casting between types always succeeds - * - * @param sourceType A type to be casted - * @param targetType A type to be casted to - * @return true if casting never fails - */ - def doesCastAlwaysSucceed(sourceType: DataType, targetType: DataType): Boolean = { - (sourceType, targetType) match { - case (_: StructType, _) | (_: ArrayType, _) => false - case (a, b) if a == b => true - case (_, _: StringType) => true - case (_: ByteType, _: ShortType | _: IntegerType | _: LongType) => true - case (_: ShortType, _: IntegerType | _: LongType) => true - case (_: IntegerType, _: LongType) => true - case (_: DateType, _: TimestampType) => true - case _ => false - } - } -} diff --git a/src/main/scala/za/co/absa/spark/commons/utils/SchemaUtils.scala b/src/main/scala/za/co/absa/spark/commons/utils/SchemaUtils.scala new file mode 100644 index 00000000..fcf1e8d6 --- /dev/null +++ b/src/main/scala/za/co/absa/spark/commons/utils/SchemaUtils.scala @@ -0,0 +1,98 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.spark.commons.utils + +import org.apache.spark.sql.types._ +import za.co.absa.spark.commons.adapters.HofsAdapter + + +object SchemaUtils extends HofsAdapter { + + /** + * Returns the parent path of a field. Returns an empty string if a root level field name is provided. + * + * @param columnName A fully qualified column name + * @return The parent column name or an empty string if the input column is a root level column + */ + def getParentPath(columnName: String): String = { + val index = columnName.lastIndexOf('.') + if (index > 0) { + columnName.substring(0, index) + } else { + "" + } + } + + /** + * Get paths for all array subfields of this given datatype + * @param path The path to the attribute + * @param name The name of the attribute + * @param dt Data type to be checked + * + */ + def getAllArraySubPaths(path: String, name: String, dt: DataType): Seq[String] = { + val currPath = appendPath(path, name) + dt match { + case s: StructType => s.fields.flatMap(f => getAllArraySubPaths(currPath, f.name, f.dataType)) + case _@ArrayType(elType, _) => getAllArraySubPaths(path, name, elType) :+ currPath + case _ => Seq() + } + } + + /** + * For a given list of field paths determines if any path pair is a subset of one another. + * + * For instance, + * - 'a.b', 'a.b.c', 'a.b.c.d' have this property. + * - 'a.b', 'a.b.c', 'a.x.y' does NOT have it, since 'a.b.c' and 'a.x.y' have diverging subpaths. 
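+   *
+   * A sketch of the expected results for the example paths above (the same calls are exercised in the tests of this patch):
+   * {{{
+   *   SchemaUtils.isCommonSubPath("a.b.c.d.e.f", "a.b.c.d", "a.b.c", "a.b", "a") // true
+   *   SchemaUtils.isCommonSubPath("a.b.c.d.e.f", "a.b.c.x", "a.b.c", "a.b", "a") // false
+   * }}}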
+   *
+   * @param paths A list of paths to be analyzed
+   * @return true if for all paths the above property holds
+   */
+  def isCommonSubPath(paths: String*): Boolean = {
+    def sliceRoot(paths: Seq[Seq[String]]): Seq[Seq[String]] = {
+      paths.map(path => path.drop(1)).filter(_.nonEmpty)
+    }
+
+    var isParentCommon = true // For Seq() the property holds by [my] convention
+    var restOfPaths: Seq[Seq[String]] = paths.map(_.split('.').toSeq).filter(_.nonEmpty)
+    while (isParentCommon && restOfPaths.nonEmpty) {
+      val parent = restOfPaths.head.head
+      isParentCommon = restOfPaths.forall(path => path.head == parent)
+      restOfPaths = sliceRoot(restOfPaths)
+    }
+    isParentCommon
+  }
+
+  /**
+   * Append a new attribute to path or empty string.
+   *
+   * @param path The dot-separated existing path
+   * @param fieldName Name of the field to be appended to the path
+   * @return The path with the new field appended or the field itself if path is empty
+   */
+  def appendPath(path: String, fieldName: String): String = {
+    if (path.isEmpty) {
+      fieldName
+    } else if (fieldName.isEmpty) {
+      path
+    } else {
+      s"$path.$fieldName"
+    }
+  }
+
+}
diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala
index 20d1e62b..f075eafa 100644
--- a/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala
+++ b/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.functions.lit
 import org.scalatest.funsuite.AnyFunSuite
 import za.co.absa.spark.commons.implicits.ColumnImplicits.ColumnEnhancements
 
-class ColumnImplicitsTest extends AnyFunSuite{
+class ColumnImplicitsTest extends AnyFunSuite {
 
   private val column: Column = lit("abcdefgh")
 
diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala b/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsTest.scala
similarity index 92%
rename from src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala
rename to src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsTest.scala
index 5eb595d6..412a5613 100644
--- a/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala
+++ b/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsTest.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.functions.lit
 import org.scalatest.funsuite.AnyFunSuite
 import za.co.absa.spark.commons.test.SparkTestBase
 
-class DataFrameImplicitsSuite extends AnyFunSuite with SparkTestBase {
+class DataFrameImplicitsTest extends AnyFunSuite with SparkTestBase with JsonTestData {
   import spark.implicits._
 
   private val columnName = "data"
@@ -55,10 +55,6 @@ class DataFrameImplicitsSuite extends AnyFunSuite with SparkTestBase {
     "z"
   )
 
-  val jsonA = """[{"id":1,"legs":[{"legid":100,"conditions":[{"checks":[{"checkNums":["1","2","3b","4","5c","6"]}],"amount":100}]}], "key" : {"alfa": "1", "beta": {"beta2": "2"}} }]"""
-  val jsonB = """[{"id":1,"legs":[{"legid":100,"conditions":[{"checks":[{"checkNums":["1","2","3b","4","5c","6"]}],"amount":100,"price":10}]}]}]"""
-  val jsonC = """[{"legs":[{"legid":100,"conditions":[{"amount":100,"checks":[{"checkNums":["1","2","3b","4","5c","6"]}]}]}],"id":1, "key" : {"beta": {"beta2": "2"}, "alfa": "1"} }]"""
-
   private val inputData = inputDataSeq.toDF(columnName)
 
   import za.co.absa.spark.commons.implicits.DataFrameImplicits.DataFrameEnhancements
 
diff --git
a/src/test/scala/za/co/absa/spark/commons/implicits/DataTypeImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/DataTypeImplicitsTest.scala new file mode 100644 index 00000000..445f14ce --- /dev/null +++ b/src/test/scala/za/co/absa/spark/commons/implicits/DataTypeImplicitsTest.scala @@ -0,0 +1,68 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.spark.commons.implicits + +import org.apache.spark.sql.types.{ArrayType, ByteType, DateType, DecimalType, IntegerType, LongType, ShortType, StringType, StructType, TimestampType} +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.spark.commons.implicits.DataTypeImplicits.DataTypeEnhancements + +class DataTypeImplicitsTest extends AnyFunSuite with JsonTestData { + + test ("Test doesCastAlwaysSucceed()") { + assert(!StructType(Seq()).doesCastAlwaysSucceed(StringType)) + assert(!ArrayType(StringType).doesCastAlwaysSucceed(StringType)) + assert(!StringType.doesCastAlwaysSucceed(ByteType)) + assert(!StringType.doesCastAlwaysSucceed(ShortType)) + assert(!StringType.doesCastAlwaysSucceed(IntegerType)) + assert(!StringType.doesCastAlwaysSucceed(LongType)) + assert(!StringType.doesCastAlwaysSucceed(DecimalType(10,10))) + assert(!StringType.doesCastAlwaysSucceed(DateType)) + assert(!StringType.doesCastAlwaysSucceed(TimestampType)) + assert(!StructType(Seq()).doesCastAlwaysSucceed(StructType(Seq()))) + assert(!ArrayType(StringType).doesCastAlwaysSucceed(ArrayType(StringType))) + + assert(!ShortType.doesCastAlwaysSucceed(ByteType)) + assert(!IntegerType.doesCastAlwaysSucceed(ByteType)) + assert(!IntegerType.doesCastAlwaysSucceed(ShortType)) + assert(!LongType.doesCastAlwaysSucceed(ByteType)) + assert(!LongType.doesCastAlwaysSucceed(ShortType)) + assert(!LongType.doesCastAlwaysSucceed(IntegerType)) + + assert(StringType.doesCastAlwaysSucceed(StringType)) + assert(ByteType.doesCastAlwaysSucceed(StringType)) + assert(ShortType.doesCastAlwaysSucceed(StringType)) + assert(IntegerType.doesCastAlwaysSucceed(StringType)) + assert(LongType.doesCastAlwaysSucceed(StringType)) + assert(DecimalType(10,10).doesCastAlwaysSucceed(StringType)) + assert(DateType.doesCastAlwaysSucceed(StringType)) + assert(TimestampType.doesCastAlwaysSucceed(StringType)) + assert(StringType.doesCastAlwaysSucceed(StringType)) + + assert(ByteType.doesCastAlwaysSucceed(ByteType)) + assert(ByteType.doesCastAlwaysSucceed(ShortType)) + assert(ByteType.doesCastAlwaysSucceed(IntegerType)) + assert(ByteType.doesCastAlwaysSucceed(LongType)) + assert(ShortType.doesCastAlwaysSucceed(ShortType)) + assert(ShortType.doesCastAlwaysSucceed(IntegerType)) + assert(ShortType.doesCastAlwaysSucceed(LongType)) + assert(IntegerType.doesCastAlwaysSucceed(IntegerType)) + assert(IntegerType.doesCastAlwaysSucceed(LongType)) + assert(LongType.doesCastAlwaysSucceed(LongType)) + assert(DateType.doesCastAlwaysSucceed(TimestampType)) + } + +} diff --git 
a/src/test/scala/za/co/absa/spark/commons/implicits/JsonTestData.scala b/src/test/scala/za/co/absa/spark/commons/implicits/JsonTestData.scala new file mode 100644 index 00000000..4a467703 --- /dev/null +++ b/src/test/scala/za/co/absa/spark/commons/implicits/JsonTestData.scala @@ -0,0 +1,63 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.spark.commons.implicits + +import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType} + +trait JsonTestData { + + protected val jsonA = """[{"id":1,"legs":[{"legid":100,"conditions":[{"checks":[{"checkNums":["1","2","3b","4","5c","6"]}],"amount":100}]}], "key" : {"alfa": "1", "beta": {"beta2": "2"}} }]""" + protected val jsonB = """[{"id":1,"legs":[{"legid":100,"conditions":[{"checks":[{"checkNums":["1","2","3b","4","5c","6"]}],"amount":100,"price":10}]}]}]""" + protected val jsonC = """[{"legs":[{"legid":100,"conditions":[{"amount":100,"checks":[{"checkNums":["1","2","3b","4","5c","6"]}]}]}],"id":1, "key" : {"beta": {"beta2": "2"}, "alfa": "1"} }]""" + protected val jsonD = """[{"legs":[{"legid":100,"conditions":[{"amount":100,"checks":[{"checkNums":["1","2","3b","4","5c","6"]}]}]}],"id":1, "key" : {"beta": {"beta2": 2}, "alfa": 1} }]""" + protected val jsonE = """[{"legs":[{"legid":100,"conditions":[{"amount":100,"checks":[{"checkNums":["1","2","3b","4","5c","6"]}]}]}],"id":1, "key" : {"beta": {"beta2": 2}, "alfa": 1}, "extra" : "a"}]""" + + protected val sample = + """{"id":1,"legs":[{"legid":100,"conditions":[{"checks":[{"checkNums":["1","2","3b","4","5c","6"]}],"amount":100}]}]}""" :: + """{"id":2,"legs":[{"legid":200,"conditions":[{"checks":[{"checkNums":["8","9","10b","11","12c","13"]}],"amount":200}]}]}""" :: + """{"id":3,"legs":[{"legid":300,"conditions":[{"checks":[],"amount": 300}]}]}""" :: + """{"id":4,"legs":[{"legid":400,"conditions":[{"checks":null,"amount": 400}]}]}""" :: + """{"id":5,"legs":[{"legid":500,"conditions":[]}]}""" :: + """{"id":6,"legs":[]}""" :: + """{"id":7}""" :: Nil + + protected val schema = StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", StructType(Seq( + StructField("c", IntegerType), + StructField("d", StructType(Seq( + StructField("e", IntegerType))), nullable = true)))), + StructField("f", StructType(Seq( + StructField("g", ArrayType.apply(StructType(Seq( + StructField("h", IntegerType)))))))))) + + protected val nestedSchema = StructType(Seq( + StructField("a", IntegerType), + StructField("b", ArrayType(StructType(Seq( + StructField("c", StructType(Seq( + StructField("d", ArrayType(StructType(Seq( + StructField("e", IntegerType)))))))))))))) + + protected val arrayOfArraysSchema = StructType(Seq( + StructField("a", ArrayType(ArrayType(IntegerType)), nullable = false), + StructField("b", ArrayType(ArrayType(StructType(Seq( + StructField("c", StringType, nullable = false) + )) + )), nullable = true) + )) + +} diff --git 
a/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsArrayTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsArrayTest.scala index 688ebc59..b1bbc226 100644 --- a/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsArrayTest.scala +++ b/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsArrayTest.scala @@ -16,29 +16,15 @@ package za.co.absa.spark.commons.implicits -import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, LongType, StringType, StructField, StructType} import org.scalatest.funsuite.AnyFunSuite import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancementsArrays import za.co.absa.spark.commons.test.SparkTestBase -class StructTypeImplicitsArrayTest extends AnyFunSuite with SparkTestBase { - - private val schema = StructType(Seq( - StructField("a", IntegerType, nullable = false), - StructField("b", StructType(Seq( - StructField("c", IntegerType), - StructField("d", StructType(Seq( - StructField("e", IntegerType))), nullable = true)))), - StructField("f", StructType(Seq( - StructField("g", ArrayType.apply(StructType(Seq( - StructField("h", IntegerType)))))))))) - - private val nestedSchema = StructType(Seq( - StructField("a", IntegerType), - StructField("b", ArrayType(StructType(Seq( - StructField("c", StructType(Seq( - StructField("d", ArrayType(StructType(Seq( - StructField("e", IntegerType)))))))))))))) +class StructTypeImplicitsArrayTest extends AnyFunSuite with SparkTestBase with JsonTestData { + + import spark.implicits._ + val df = spark.read.json(sample.toDS) test("Testing getFirstArrayPath") { assertResult("f.g")(schema.getFirstArrayPath("f.g.h")) @@ -51,17 +37,6 @@ class StructTypeImplicitsArrayTest extends AnyFunSuite with SparkTestBase { assertResult(Seq("b", "b.c.d"))(nestedSchema.getAllArraysInPath("b.c.d.e")) } - val sample = - """{"id":1,"legs":[{"legid":100,"conditions":[{"checks":[{"checkNums":["1","2","3b","4","5c","6"]}],"amount":100}]}]}""" :: - """{"id":2,"legs":[{"legid":200,"conditions":[{"checks":[{"checkNums":["8","9","10b","11","12c","13"]}],"amount":200}]}]}""" :: - """{"id":3,"legs":[{"legid":300,"conditions":[{"checks":[],"amount": 300}]}]}""" :: - """{"id":4,"legs":[{"legid":400,"conditions":[{"checks":null,"amount": 400}]}]}""" :: - """{"id":5,"legs":[{"legid":500,"conditions":[]}]}""" :: - """{"id":6,"legs":[]}""" :: - """{"id":7}""" :: Nil - import spark.implicits._ - val df = spark.read.json(sample.toDS) - test("Test isNonNestedArray") { assert(df.schema.isNonNestedArray("legs")) assert(!df.schema.isNonNestedArray("legs.conditions")) @@ -71,7 +46,7 @@ class StructTypeImplicitsArrayTest extends AnyFunSuite with SparkTestBase { assert(!df.schema.isNonNestedArray("legs.legid")) } - test("Test isNonArray") { + test("Test isOfType ArrayType") { assert(df.schema.isOfType[ArrayType]("legs")) assert(df.schema.isOfType[ArrayType]("legs.conditions")) assert(df.schema.isOfType[ArrayType]("legs.conditions.checks")) diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsTest.scala index 58664a77..017b3bf3 100644 --- a/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsTest.scala +++ b/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsTest.scala @@ -21,39 +21,9 @@ import org.scalatest.funsuite.AnyFunSuite import 
za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements
import za.co.absa.spark.commons.test.SparkTestBase
 
-class StructTypeImplicitsTest extends AnyFunSuite with SparkTestBase {
+class StructTypeImplicitsTest extends AnyFunSuite with SparkTestBase with JsonTestData {
   // scalastyle:off magic.number
 
-  private val schema = StructType(Seq(
-    StructField("a", IntegerType, nullable = false),
-    StructField("b", StructType(Seq(
-      StructField("c", IntegerType),
-      StructField("d", StructType(Seq(
-        StructField("e", IntegerType))), nullable = true)))),
-    StructField("f", StructType(Seq(
-      StructField("g", ArrayType.apply(StructType(Seq(
-        StructField("h", IntegerType))))))))))
-
-  private val nestedSchema = StructType(Seq(
-    StructField("a", IntegerType),
-    StructField("b", ArrayType(StructType(Seq(
-      StructField("c", StructType(Seq(
-        StructField("d", ArrayType(StructType(Seq(
-          StructField("e", IntegerType))))))))))))))
-
-  private val arrayOfArraysSchema = StructType(Seq(
-    StructField("a", ArrayType(ArrayType(IntegerType)), nullable = false),
-    StructField("b", ArrayType(ArrayType(StructType(Seq(
-      StructField("c", StringType, nullable = false)
-    ))
-    )), nullable = true)
-  ))
-
-  private val structFieldNoMetadata = StructField("a", IntegerType)
-
-  private val structFieldWithMetadataNotSourceColumn = StructField("a", IntegerType, nullable = false, new MetadataBuilder().putString("meta", "data").build)
-  private val structFieldWithMetadataSourceColumn = StructField("a", IntegerType, nullable = false, new MetadataBuilder().putString("sourcecolumn", "override_a").build)
-
   test("Testing getFieldType") {
     import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements
 
@@ -179,4 +149,71 @@ class StructTypeImplicitsTest extends AnyFunSuite with SparkTestBase {
     assert(name2 == "value_3")
   }
 
+  import spark.implicits._
+
+  test("be true for D in E but not vice versa") {
+    val schemaD = spark.read.json(Seq(jsonD).toDS).schema
+    val schemaE = spark.read.json(Seq(jsonE).toDS).schema
+
+    assert(schemaD.isSubset(schemaE))
+    assert(!schemaE.isSubset(schemaD))
+  }
+
+  test("be false for A and B in both directions") {
+    val schemaA = spark.read.json(Seq(jsonA).toDS).schema
+    val schemaB = spark.read.json(Seq(jsonB).toDS).schema
+
+    assert(!schemaA.isSubset(schemaB))
+    assert(!schemaB.isSubset(schemaA))
+  }
+
+  test("say true for the same schemas") {
+    val dfA1 = spark.read.json(Seq(jsonA).toDS)
+    val dfA2 = spark.read.json(Seq(jsonA).toDS)
+
+    assert(dfA1.schema.isEquivalent(dfA2.schema))
+  }
+
+  test("say false when first schema has an extra field") {
+    val dfA = spark.read.json(Seq(jsonA).toDS)
+    val dfB = spark.read.json(Seq(jsonB).toDS)
+
+    assert(!dfA.schema.isEquivalent(dfB.schema))
+  }
+
+  test("say false when second schema has an extra field") {
+    val dfA = spark.read.json(Seq(jsonA).toDS)
+    val dfB = spark.read.json(Seq(jsonB).toDS)
+
+    assert(!dfB.schema.isEquivalent(dfA.schema))
+  }
+
+  test("produce a list of differences with path for schemas with different columns") {
+    val schemaA = spark.read.json(Seq(jsonA).toDS).schema
+    val schemaB = spark.read.json(Seq(jsonB).toDS).schema
+
+    assertResult(schemaA.diffSchema(schemaB))(List("key cannot be found in both schemas"))
+    assertResult(schemaB.diffSchema(schemaA))(List("legs.conditions.price cannot be found in both schemas"))
+  }
+
+  test("produce a list of differences with path for schemas with different column types") {
+    val schemaC = spark.read.json(Seq(jsonC).toDS).schema
+    val schemaD =
spark.read.json(Seq(jsonD).toDS).schema + + val result = List( + "key.alfa data type doesn't match (string) vs (long)", + "key.beta.beta2 data type doesn't match (string) vs (long)" + ) + + assertResult(schemaC.diffSchema(schemaD))(result) + } + + test("produce an empty list for identical schemas") { + val schemaA = spark.read.json(Seq(jsonA).toDS).schema + val schemaB = spark.read.json(Seq(jsonA).toDS).schema + + assert(schemaA.diffSchema(schemaB).isEmpty) + assert(schemaB.diffSchema(schemaA).isEmpty) + } + } diff --git a/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilSuite.scala b/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilSuite.scala deleted file mode 100644 index 6ddf76de..00000000 --- a/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilSuite.scala +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright 2021 ABSA Group Limited - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package za.co.absa.spark.commons.schema - -import org.apache.spark.sql.types._ -import org.scalatest.funsuite.AnyFunSuite -import org.scalatest.matchers.should.Matchers -import za.co.absa.spark.commons.schema.SchemaUtils._ - -class SchemaUtilSuite extends AnyFunSuite with Matchers { - // scalastyle:off magic.number - - - test ("Test doesCastAlwaysSucceed()") { - assert(!doesCastAlwaysSucceed(StructType(Seq()), StringType)) - assert(!doesCastAlwaysSucceed(ArrayType(StringType), StringType)) - assert(!doesCastAlwaysSucceed(StringType, ByteType)) - assert(!doesCastAlwaysSucceed(StringType, ShortType)) - assert(!doesCastAlwaysSucceed(StringType, IntegerType)) - assert(!doesCastAlwaysSucceed(StringType, LongType)) - assert(!doesCastAlwaysSucceed(StringType, DecimalType(10,10))) - assert(!doesCastAlwaysSucceed(StringType, DateType)) - assert(!doesCastAlwaysSucceed(StringType, TimestampType)) - assert(!doesCastAlwaysSucceed(StructType(Seq()), StructType(Seq()))) - assert(!doesCastAlwaysSucceed(ArrayType(StringType), ArrayType(StringType))) - - assert(!doesCastAlwaysSucceed(ShortType, ByteType)) - assert(!doesCastAlwaysSucceed(IntegerType, ByteType)) - assert(!doesCastAlwaysSucceed(IntegerType, ShortType)) - assert(!doesCastAlwaysSucceed(LongType, ByteType)) - assert(!doesCastAlwaysSucceed(LongType, ShortType)) - assert(!doesCastAlwaysSucceed(LongType, IntegerType)) - - assert(doesCastAlwaysSucceed(StringType, StringType)) - assert(doesCastAlwaysSucceed(ByteType, StringType)) - assert(doesCastAlwaysSucceed(ShortType, StringType)) - assert(doesCastAlwaysSucceed(IntegerType, StringType)) - assert(doesCastAlwaysSucceed(LongType, StringType)) - assert(doesCastAlwaysSucceed(DecimalType(10,10), StringType)) - assert(doesCastAlwaysSucceed(DateType, StringType)) - assert(doesCastAlwaysSucceed(TimestampType, StringType)) - assert(doesCastAlwaysSucceed(StringType, StringType)) - - assert(doesCastAlwaysSucceed(ByteType, ByteType)) - assert(doesCastAlwaysSucceed(ByteType, ShortType)) - assert(doesCastAlwaysSucceed(ByteType, IntegerType)) - assert(doesCastAlwaysSucceed(ByteType, LongType)) - 
assert(doesCastAlwaysSucceed(ShortType, ShortType)) - assert(doesCastAlwaysSucceed(ShortType, IntegerType)) - assert(doesCastAlwaysSucceed(ShortType, LongType)) - assert(doesCastAlwaysSucceed(IntegerType, IntegerType)) - assert(doesCastAlwaysSucceed(IntegerType, LongType)) - assert(doesCastAlwaysSucceed(LongType, LongType)) - assert(doesCastAlwaysSucceed(DateType, TimestampType)) - } - - test("Test isCommonSubPath()") { - assert (isCommonSubPath()) - assert (isCommonSubPath("a")) - assert (isCommonSubPath("a.b.c.d.e.f", "a.b.c.d", "a.b.c", "a.b", "a")) - assert (!isCommonSubPath("a.b.c.d.e.f", "a.b.c.x", "a.b.c", "a.b", "a")) - } -} diff --git a/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilsSpec.scala b/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilsSpec.scala deleted file mode 100644 index f17763b5..00000000 --- a/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilsSpec.scala +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Copyright 2021 ABSA Group Limited - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package za.co.absa.spark.commons.schema - -import org.scalatest.BeforeAndAfterAll -import org.scalatest.flatspec.AnyFlatSpec -import org.scalatest.matchers.should.Matchers -import za.co.absa.spark.commons.test.SparkTestBase - -class SchemaUtilsSpec extends AnyFlatSpec with Matchers with BeforeAndAfterAll with SparkTestBase { - import spark.implicits._ - - val jsonA = """[{"id":1,"legs":[{"legid":100,"conditions":[{"checks":[{"checkNums":["1","2","3b","4","5c","6"]}],"amount":100}]}], "key" : {"alfa": "1", "beta": {"beta2": "2"}} }]""" - val jsonB = """[{"id":1,"legs":[{"legid":100,"conditions":[{"checks":[{"checkNums":["1","2","3b","4","5c","6"]}],"amount":100,"price":10}]}]}]""" - val jsonC = """[{"legs":[{"legid":100,"conditions":[{"amount":100,"checks":[{"checkNums":["1","2","3b","4","5c","6"]}]}]}],"id":1, "key" : {"beta": {"beta2": "2"}, "alfa": "1"} }]""" - val jsonD = """[{"legs":[{"legid":100,"conditions":[{"amount":100,"checks":[{"checkNums":["1","2","3b","4","5c","6"]}]}]}],"id":1, "key" : {"beta": {"beta2": 2}, "alfa": 1} }]""" - val jsonE = """[{"legs":[{"legid":100,"conditions":[{"amount":100,"checks":[{"checkNums":["1","2","3b","4","5c","6"]}]}]}],"id":1, "key" : {"beta": {"beta2": 2}, "alfa": 1}, "extra" : "a"}]""" - - behavior of "isSameSchema" - - it should "say true for the same schemas" in { - val dfA1 = spark.read.json(Seq(jsonA).toDS) - val dfA2 = spark.read.json(Seq(jsonA).toDS) - - SchemaUtils.equivalentSchemas(dfA1.schema, dfA2.schema) should be(true) - } - - it should "say false when first schema has an extra field" in { - val dfA = spark.read.json(Seq(jsonA).toDS) - val dfB = spark.read.json(Seq(jsonB).toDS) - - SchemaUtils.equivalentSchemas(dfA.schema, dfB.schema) should be(false) - } - - it should "say false when second schema has an extra field" in { - val dfA = spark.read.json(Seq(jsonA).toDS) - val dfB = spark.read.json(Seq(jsonB).toDS) - - SchemaUtils.equivalentSchemas(dfB.schema, dfA.schema) should be(false) - } - - 
behavior of "diffSchema" - - it should "produce a list of differences with path for schemas with different columns" in { - val schemaA = spark.read.json(Seq(jsonA).toDS).schema - val schemaB = spark.read.json(Seq(jsonB).toDS).schema - - SchemaUtils.diffSchema(schemaA, schemaB) should equal(List("key cannot be found in both schemas")) - SchemaUtils.diffSchema(schemaB, schemaA) should equal(List("legs.conditions.price cannot be found in both schemas")) - } - - it should "produce a list of differences with path for schemas with different column types" in { - val schemaC = spark.read.json(Seq(jsonC).toDS).schema - val schemaD = spark.read.json(Seq(jsonD).toDS).schema - - val result = List( - "key.alfa data type doesn't match (string) vs (long)", - "key.beta.beta2 data type doesn't match (string) vs (long)" - ) - - SchemaUtils.diffSchema(schemaC, schemaD) should equal(result) - } - - it should "produce an empty list for identical schemas" in { - val schemaA = spark.read.json(Seq(jsonA).toDS).schema - val schemaB = spark.read.json(Seq(jsonA).toDS).schema - - SchemaUtils.diffSchema(schemaA, schemaB).isEmpty should be(true) - SchemaUtils.diffSchema(schemaB, schemaA).isEmpty should be(true) - } - - behavior of "isSubset" - - it should "be true for E in D but not vice versa" in { - val schemaD = spark.read.json(Seq(jsonD).toDS).schema - val schemaE = spark.read.json(Seq(jsonE).toDS).schema - - SchemaUtils.isSubset(schemaD, schemaE) should be(true) - SchemaUtils.isSubset(schemaE, schemaD) should be(false) - } - - it should "be false for A and B in both directions" in { - val schemaA = spark.read.json(Seq(jsonA).toDS).schema - val schemaB = spark.read.json(Seq(jsonA).toDS).schema - - SchemaUtils.isSubset(schemaA, schemaB) should be(true) - SchemaUtils.isSubset(schemaB, schemaA) should be(true) - } -} diff --git a/src/test/scala/za/co/absa/spark/commons/utils/SchemaUtilTest.scala b/src/test/scala/za/co/absa/spark/commons/utils/SchemaUtilTest.scala new file mode 100644 index 00000000..939b0d02 --- /dev/null +++ b/src/test/scala/za/co/absa/spark/commons/utils/SchemaUtilTest.scala @@ -0,0 +1,40 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package za.co.absa.spark.commons.utils
+
+import org.apache.spark.sql.types.{ArrayType, IntegerType, StructField, StructType}
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.should.Matchers
+import za.co.absa.spark.commons.utils.SchemaUtils._
+
+class SchemaUtilTest extends AnyFunSuite with Matchers {
+  // scalastyle:off magic.number
+
+  test("Test isCommonSubPath()") {
+    assert (isCommonSubPath())
+    assert (isCommonSubPath("a"))
+    assert (isCommonSubPath("a.b.c.d.e.f", "a.b.c.d", "a.b.c", "a.b", "a"))
+    assert (!isCommonSubPath("a.b.c.d.e.f", "a.b.c.x", "a.b.c", "a.b", "a"))
+  }
+
+  test("Test getParentPath") {
+    assertResult(getParentPath("a.b.c.d.e"))("a.b.c.d")
+    assertResult(getParentPath("a"))("")
+    assertResult(getParentPath("a.bcd"))("a")
+    assertResult(getParentPath(""))("")
+  }
+}

From 11807aef1099193dd0f2303271e9da84d75e6337 Mon Sep 17 00:00:00 2001
From: Adrian Olosutean
Date: Tue, 8 Feb 2022 08:25:04 +0100
Subject: [PATCH 17/19] #22 merging

---
 README.md                                     | 21 +++++++++----------
 .../spark/commons/SparkVersionGuard.scala     |  2 --
 .../commons/implicits/DataTypeImplicits.scala |  2 +-
 .../implicits/StructTypeImplicits.scala       |  2 +-
 .../implicits/DataFrameImplicitsTest.scala    |  2 +-
 .../implicits/DataTypeImplicitsTest.scala     |  2 +-
 .../commons/implicits/JsonTestData.scala      |  2 +-
 .../StructTypeImplicitsArrayTest.scala        |  2 +-
 ...maUtilTest.scala => SchemaUtilsTest.scala} |  3 +--
 9 files changed, 17 insertions(+), 21 deletions(-)
 rename src/test/scala/za/co/absa/spark/commons/utils/{SchemaUtilTest.scala => SchemaUtilsTest.scala} (90%)

diff --git a/README.md b/README.md
index 19bc959b..78e98185 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@ val myListener = new MyQueryExecutionListener with NonFatalQueryExecutionListene
 spark.listenerManager.register(myListener)
 ```
 
-### Spark Schema Utils
+### Schema Utils
 
 >
 >**Note:**
@@ -25,31 +25,30 @@ spark.listenerManager.register(myListener)
 >|Spark| 2.4 | 3.1 | 3.2 |
 >|Json4s| 3.5 | 3.7 | 3.7 |
 >|Jackson| 2.6 | 2.10 | 2.12 |
-_Spark Schema Utils_ provides methods for working with schemas, its comparison and alignment.
+_Schema Utils_ provides methods for working with schemas, their comparison and alignment.
 
-1. Schema comparison returning true/false. Ignores the order of columns
+1. Returns the parent path of a field. Returns an empty string if a root level field name is provided.
 
   ```scala
-    SchemaUtils.equivalentSchemas(other)
+    SchemaUtils.getParentPath(columnName)
  ```
 
-2. Schema comparison returning difference. Ignores the order of columns
+2. Get paths for all array subfields of a given datatype
 
  ```scala
-    SchemaUtils.diff(other)
+    SchemaUtils.getAllArraySubPaths(path, name, dataType)
  ```
 
-3. Schema selector generator which provides a List of columns to be used in a
-select to order and positionally filter columns of a DataFrame
+3. For a given list of field paths determines if any path pair is a subset of one another.
 
  ```scala
-    SchemaUtils.getDataFrameSelector(schema)
+    SchemaUtils.isCommonSubPath(paths)
  ```
 
-4. Dataframe alignment method using the `getDataFrameSelector` method.
+4. Appends a new attribute to a path, or returns the attribute itself if the path is empty.
```scala - SchemaUtils.alignSchema(dataFrameToBeAligned, modelSchema) + SchemaUtils.appendPath(path, fieldName) ``` ### ColumnImplicits diff --git a/src/main/scala/za/co/absa/spark/commons/SparkVersionGuard.scala b/src/main/scala/za/co/absa/spark/commons/SparkVersionGuard.scala index d78c6d54..a8e92046 100644 --- a/src/main/scala/za/co/absa/spark/commons/SparkVersionGuard.scala +++ b/src/main/scala/za/co/absa/spark/commons/SparkVersionGuard.scala @@ -16,12 +16,10 @@ package za.co.absa.spark.commons -import org.apache.spark.sql.functions.col import org.slf4j.{Logger, LoggerFactory} import za.co.absa.commons.version.Version import za.co.absa.commons.version.Version.VersionStringInterpolator import za.co.absa.commons.version.impl.SemVer20Impl.SemanticVersion -import za.co.absa.spark.commons.implicits._ object SparkVersionGuard { val minSpark2XVersionIncluded: SemanticVersion = semver"2.4.2" diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/DataTypeImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/DataTypeImplicits.scala index b3498005..8a8e38e7 100644 --- a/src/main/scala/za/co/absa/spark/commons/implicits/DataTypeImplicits.scala +++ b/src/main/scala/za/co/absa/spark/commons/implicits/DataTypeImplicits.scala @@ -16,7 +16,7 @@ package za.co.absa.spark.commons.implicits -import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, NullType, ShortType, StringType, StructType, TimestampType} +import org.apache.spark.sql.types._ import za.co.absa.spark.commons.implicits.ArrayTypeImplicits.ArrayTypeEnhancements import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala index a14142e5..3fba43e7 100644 --- a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala +++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala @@ -18,7 +18,7 @@ package za.co.absa.spark.commons.implicits import org.apache.spark.sql.Column import org.apache.spark.sql.functions.{col, struct} -import org.apache.spark.sql.types.{ArrayType, DataType, NullType, StructField, StructType} +import org.apache.spark.sql.types._ import za.co.absa.spark.commons.implicits.DataTypeImplicits.DataTypeEnhancements import za.co.absa.spark.commons.implicits.StructFieldImplicits.StructFieldEnhamcenets import za.co.absa.spark.commons.utils.SchemaUtils.{getAllArraySubPaths, isCommonSubPath, transform} diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsTest.scala index 412a5613..3de27353 100644 --- a/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsTest.scala +++ b/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsTest.scala @@ -16,8 +16,8 @@ package za.co.absa.spark.commons.implicits -import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.{AnalysisException, DataFrame} import org.scalatest.funsuite.AnyFunSuite import za.co.absa.spark.commons.test.SparkTestBase diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/DataTypeImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/DataTypeImplicitsTest.scala index 445f14ce..1049667e 100644 --- 
From 825ada47a91865e625d2afc44b51b9c1d03354ba Mon Sep 17 00:00:00 2001
From: Adrian Olosutean
Date: Thu, 10 Feb 2022 16:59:37 +0100
Subject: [PATCH 18/19] #22 isOfType proper implementation

---
 .../implicits/StructTypeImplicits.scala       | 12 ++++---
 .../commons/implicits/JsonTestData.scala      |  3 ++
 .../StructTypeImplicitsArrayTest.scala        | 35 ++++++++++++++++++-
 3 files changed, 45 insertions(+), 5 deletions(-)

diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala
index 3fba43e7..a7289b08 100644
--- a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala
+++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala
@@ -24,7 +24,7 @@ import za.co.absa.spark.commons.implicits.StructFieldImplicits.StructFieldEnhamc
 import za.co.absa.spark.commons.utils.SchemaUtils.{getAllArraySubPaths, isCommonSubPath, transform}
 
 import scala.annotation.tailrec
-import scala.reflect.runtime.universe._
+import scala.reflect.ClassTag
 import scala.util.Try
 
 object StructTypeImplicits {
@@ -252,12 +252,16 @@
       schema.map(field => field.name.toLowerCase() -> field).toMap
     }
 
-    def isOfType[T <: DataType](path: String)(implicit tag: TypeTag[T]): Boolean = {
+    def isOfType[T <: DataType](path: String)(implicit ev: ClassTag[T]): Boolean = {
       val fieldType = getFieldType(path).getOrElse(NullType)
-      val classT: Class[_] = runtimeMirror(getClass.getClassLoader).runtimeClass(tag.tpe.typeSymbol.asClass)
-      classT.equals(fieldType.getClass)
+
+      fieldType match {
+        case _: T => true
+        case _ => false
+      }
     }
 
+
     protected def evaluateConditionsForField(structField: StructType, path: Seq[String], fieldPathName: String,
                                              applyArrayHelper: Boolean, applyLeafCondition: Boolean = false,
                                              conditionLeafSh: StructType => Boolean = _ => false): Boolean = {
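This rewrite is the substance of the patch. The old `TypeTag` version resolved `T` to a single runtime class and required exact equality with `fieldType.getClass`, which both rules out abstract supertypes such as `NumericType` and can miss Spark's singleton types (the `LongType` object has runtime class `LongType$`, never equal to `classOf[LongType]`). The `ClassTag` version pattern-matches instead, which respects subtyping. A self-contained sketch of the mechanism, illustrative only and not the library code:

```scala
import scala.reflect.ClassTag

// With an implicit ClassTag in scope, `case _: T` compiles into a runtime
// ev.runtimeClass.isInstance check instead of being erased, so a value of a
// subtype also matches an abstract supertype. Only the class is checked;
// type arguments would not be, but Spark's DataType hierarchy has none.
def isInstanceOfT[T](value: Any)(implicit ev: ClassTag[T]): Boolean = value match {
  case _: T => true
  case _    => false
}

isInstanceOfT[String]("abc")  // true
isInstanceOfT[Int]("abc")     // false
isInstanceOfT[Int](42)        // true (the ClassTag handles boxing)
```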
diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/JsonTestData.scala b/src/test/scala/za/co/absa/spark/commons/implicits/JsonTestData.scala
index 941622fe..b7c21c34 100644
--- a/src/test/scala/za/co/absa/spark/commons/implicits/JsonTestData.scala
+++ b/src/test/scala/za/co/absa/spark/commons/implicits/JsonTestData.scala
@@ -35,6 +35,9 @@ trait JsonTestData {
     """{"id":6,"legs":[]}""" ::
     """{"id":7}""" :: Nil
 
+  protected val sampleA: Seq[String] = sample :+ """{"id":1,"legs":[{"legid":100,"conditions":[{"checks":[{"checkNums":["1","2","3b","4","5c","6"]}],"amount":100}]}], "key" : {"alfa": "1", "beta": {"beta2": "2"}}}"""
+  protected val sampleE: Seq[String] = sample :+ """{"legs":[{"legid":100,"conditions":[{"amount":100,"checks":[{"checkNums":["1","2","3b","4","5c","6"]}]}]}],"id":1, "key" : {"beta": {"beta2": 2}, "alfa": 1}, "extra" : "a"}"""
+
   protected val schema = StructType(Seq(
     StructField("a", IntegerType, nullable = false),
     StructField("b", StructType(Seq(
diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsArrayTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsArrayTest.scala
index b0037594..620dd8c1 100644
--- a/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsArrayTest.scala
+++ b/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsArrayTest.scala
@@ -16,7 +16,7 @@
 
 package za.co.absa.spark.commons.implicits
 
-import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, LongType, NumericType, StringType, StructField, StructType}
 import org.scalatest.funsuite.AnyFunSuite
 import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancementsArrays
 import za.co.absa.spark.commons.test.SparkTestBase
@@ -26,6 +26,9 @@ class StructTypeImplicitsArrayTest extends AnyFunSuite with SparkTestBase with J
   import spark.implicits._
 
   val df = spark.read.json(sample.toDS)
+  val dfA = spark.read.json(sampleA.toDS)
+  val dfE = spark.read.json(sampleE.toDS)
+
   test("Testing getFirstArrayPath") {
     assertResult("f.g")(schema.getFirstArrayPath("f.g.h"))
     assertResult("f.g")(schema.getFirstArrayPath("f.g"))
@@ -51,10 +54,40 @@ class StructTypeImplicitsArrayTest extends AnyFunSuite with SparkTestBase with J
     assert(df.schema.isOfType[ArrayType]("legs.conditions"))
     assert(df.schema.isOfType[ArrayType]("legs.conditions.checks"))
     assert(df.schema.isOfType[ArrayType]("legs.conditions.checks.checkNums"))
+
     assert(!df.schema.isOfType[ArrayType]("id"))
     assert(!df.schema.isOfType[ArrayType]("legs.legid"))
   }
 
+  test("Test isOfType LongType") {
+    assert(dfA.schema.isOfType[LongType]("legs.legid"))
+    assert(dfA.schema.isOfType[LongType]("id"))
+    assert(!dfA.schema.isOfType[StringType]("id"))
+    assert(dfA.schema.isOfType[NumericType]("id"))
+    assert(dfA.schema.isOfType[DataType]("id"))
+  }
+
+  test("Test isOfType StructType") {
+    assert(!dfE.schema.isOfType[StructType]("id"))
+    assert(dfE.schema.isOfType[StructType]("key"))
+    assert(!dfE.schema.isOfType[StructType]("key.alfa"))
+    assert(dfE.schema.isOfType[StructType]("key.beta"))
+    assert(!dfE.schema.isOfType[StructType]("key.beta.beta2"))
+    assert(!dfE.schema.isOfType[StructType]("extra"))
+  }
+
+  test("Test isOfType StringType") {
+    assert(!dfA.schema.isOfType[StringType]("id"))
+    assert(!dfA.schema.isOfType[StringType]("key"))
+    assert(dfA.schema.isOfType[StringType]("key.alfa"))
+    assert(!dfA.schema.isOfType[StringType]("key.beta"))
+    assert(dfA.schema.isOfType[StringType]("key.beta.beta2"))
+    assert(!dfA.schema.isOfType[StringType]("extra"))
+
+    assert(!dfE.schema.isOfType[StringType]("key.alfa"))
+    assert(dfE.schema.isOfType[StringType]("extra"))
+  }
+
   test("Test getDeepestCommonArrayPath() for a path without an array") {
 
     val schema = StructType(Seq[StructField](
       StructField("a",

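One detail worth noting about the two fixtures used above: `sampleA` quotes the scalar values under `key` while `sampleE` leaves them bare, and Spark's JSON schema inference maps quoted scalars to `StringType` and bare integers to `LongType`. That is what drives the opposite `isOfType[StringType]("key.alfa")` expectations in the new tests. A hypothetical spark-shell check (assuming a `SparkSession` in scope as `spark`, as `SparkTestBase` provides in the tests):

```scala
import spark.implicits._

spark.read.json(Seq("""{"alfa": "1"}""").toDS).schema
// StructType(StructField(alfa,StringType,true))  -> isOfType[StringType]("alfa") holds

spark.read.json(Seq("""{"alfa": 1}""").toDS).schema
// StructType(StructField(alfa,LongType,true))    -> isOfType[LongType]("alfa") holds
```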
From 1fa973cc9ac1b549f0baabedcb01588a18bce8cc Mon Sep 17 00:00:00 2001
From: Adrian Olosutean
Date: Thu, 10 Feb 2022 21:50:56 +0100
Subject: [PATCH 19/19] #22 fix doc

---
 README.md | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 7ed6ec8a..96f61d02 100644
--- a/README.md
+++ b/README.md
@@ -141,7 +141,7 @@ Of them, metadata methods are:
 
 _ArrayTypeImplicits_ provides implicit methods for working with ArrayType objects.
 
-1. Get a field from a text path
+1. Checks if the arraytype is equivalent to another
 
     ```scala
     arrayType.isEquivalentArrayType(otherArrayType)
@@ -158,7 +158,7 @@
 
 _DataTypeImplicits_ provides implicit methods for working with DataType objects.
 
-1. Get a field from a text path
+1. Checks if the datatype is equivalent to another
 
     ```scala
     dataType.isEquivalentDataType(otherDt)
@@ -225,7 +225,7 @@ _StructTypeImplicits_ provides implicit methods for working with StructType obje
 
     ```scala
     structType.isOnlyField(columnName)
     ```
-9. Checks if 2 schemas are equivalent
+9. Checks if 2 structtypes are equivalent
 
     ```scala
     structType.isEquivalent(other)