diff --git a/README.md b/README.md
index f057717e..96f61d02 100644
--- a/README.md
+++ b/README.md
@@ -43,7 +43,7 @@ val myListener = new MyQueryExecutionListener with NonFatalQueryExecutionListene
 spark.listenerManager.register(myListener)
 ```

-### Spark Schema Utils
+### Schema Utils
 >
 >**Note:**
@@ -55,31 +55,30 @@ spark.listenerManager.register(myListener)
 >|Spark| 2.4 | 3.1 | 3.2 |
 >|Json4s| 3.5 | 3.7 | 3.7 |
 >|Jackson| 2.6 | 2.10 | 2.12 |

-_Spark Schema Utils_ provides methods for working with schemas, its comparison and alignment.
+_Schema Utils_ provides methods for working with schemas, their comparison and alignment.

-1. Schema comparison returning true/false. Ignores the order of columns
+1. Returns the parent path of a field. Returns an empty string if a root level field name is provided.

    ```scala
-   SchemaUtils.equivalentSchemas(schema1, schema2)
+   SchemaUtils.getParentPath(columnName)
    ```

-2. Schema comparison returning difference. Ignores the order of columns
+2. Gets paths for all array subfields of the given data type

    ```scala
-   SchemaUtils.diff(schema1, schema2)
+   SchemaUtils.getAllArraySubPaths(path, name, dataType)
    ```

-3. Schema selector generator which provides a List of columns to be used in a
-select to order and positionally filter columns of a DataFrame
+3. For a given list of field paths, determines if any path pair is a subset of one another

    ```scala
-   SchemaUtils.getDataFrameSelector(schema)
+   SchemaUtils.isCommonSubPath(paths)
    ```

-4. Dataframe alignment method using the `getDataFrameSelector` method.
+4. Appends a new field name to a path; returns the field name alone if the path is empty

    ```scala
-   SchemaUtils.alignSchema(dataFrameToBeAligned, modelSchema)
+   SchemaUtils.appendPath(path, fieldName)
    ```

 ### ColumnImplicits
@@ -97,11 +96,11 @@ _ColumnImplicits_ provide implicit methods for transforming Spark Columns
    ```scala
    column.zeroBasedSubstr(startPos)
    ```
+
 3. Returns a column with the requested substring, shifting the substring indexation to be in accordance with Scala/Java. If the provided starting position is negative, it is counted from the end of the string. If the requested length exceeds the rest of the string, all the remaining characters are taken.
-
    ```scala
    column.zeroBasedSubstr(startPos, length)
    ```
@@ -109,6 +108,7 @@ _ColumnImplicits_ provide implicit methods for transforming Spark Columns
 ### StructFieldImplicits
 _StructFieldImplicits_ provides implicit methods for working with StructField objects.
+
 Of them, metadata methods are:

 1. Gets the metadata Option[String] value given a key
@@ -135,6 +135,156 @@ Of them, metadata methods are:
    ```scala
    structField.metadata.hasKey(key)
    ```
+
+### ArrayTypeImplicits
+
+_ArrayTypeImplicits_ provides implicit methods for working with ArrayType objects.
+
+1. Checks if the ArrayType is equivalent to another
+
+   ```scala
+   arrayType.isEquivalentArrayType(otherArrayType)
+   ```
+
+2. For an array of arrays, gets the final element type at the bottom of the array
+
+   ```scala
+   arrayType.getDeepestArrayType()
+   ```
+
+### DataTypeImplicits
+
+_DataTypeImplicits_ provides implicit methods for working with DataType objects.
+
+1. Checks if the DataType is equivalent to another
+
+   ```scala
+   dataType.isEquivalentDataType(otherDt)
+   ```
+
+2. Checks if a casting between types always succeeds
+
+   ```scala
+   dataType.doesCastAlwaysSucceed(otherDt)
+   ```
+
+3. Checks if the type is primitive
+
+   ```scala
+   dataType.isPrimitive()
+   ```
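+
+A minimal usage sketch of the `DataTypeImplicits` checks above (these are pure type-level operations, no SparkSession needed):
+
+   ```scala
+   import org.apache.spark.sql.types._
+   import za.co.absa.spark.commons.implicits.DataTypeImplicits.DataTypeEnhancements
+
+   ByteType.doesCastAlwaysSucceed(LongType)      // true: widening casts never fail
+   StringType.doesCastAlwaysSucceed(IntegerType) // false: parsing a string may fail
+   StringType.isPrimitive()                      // true
+   ```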
+### StructTypeImplicits
+
+_StructTypeImplicits_ provides implicit methods for working with StructType objects.
+
+1. Gets a field from a text path
+
+   ```scala
+   structType.getField(path)
+   ```
+
+2. Gets the type of a field from a text path
+
+   ```scala
+   structType.getFieldType(path)
+   ```
+
+3. Checks if the specified path is an array of structs
+
+   ```scala
+   structType.isColumnArrayOfStruct(path)
+   ```
+
+4. Gets the nullability of a field from a text path
+
+   ```scala
+   structType.getFieldNullability(path)
+   ```
+
+5. Checks if a field specified by a path exists
+
+   ```scala
+   structType.fieldExists(path)
+   ```
+
+6. Gets paths for all array fields in the schema
+
+   ```scala
+   structType.getAllArrayPaths()
+   ```
+
+7. Gets the closest unique column name
+
+   ```scala
+   structType.getClosestUniqueName(desiredName)
+   ```
+
+8. Checks if a field is the only field in a struct
+
+   ```scala
+   structType.isOnlyField(columnName)
+   ```
+
+9. Checks if two StructTypes are equivalent
+
+   ```scala
+   structType.isEquivalent(other)
+   ```
+
+10. Returns a list of differences between one schema and the other
+
+   ```scala
+   structType.diffSchema(otherSchema, parent)
+   ```
+
+11. Checks if a field is of the specified type
+
+   ```scala
+   structType.isOfType[ArrayType](path)
+   ```
+
+12. Checks if this schema is a subset of the specified schema
+
+   ```scala
+   structType.isSubset(other)
+   ```
+
+13. Returns a data selector that can be used to align the schema of a DataFrame
+
+   ```scala
+   structType.getDataFrameSelector()
+   ```
+
+### StructTypeArrayImplicits
+
+_StructTypeArrayImplicits_ provides implicit methods for working with arrays inside StructType objects.
+
+1. Gets the first array column's path out of a complete path
+
+   ```scala
+   structType.getFirstArrayPath(path)
+   ```
+
+2. Gets all array columns' paths out of a complete path
+
+   ```scala
+   structType.getAllArraysInPath(path)
+   ```
+
+3. For a given list of field paths, determines the deepest common array path
+
+   ```scala
+   structType.getDeepestCommonArrayPath(fieldPaths)
+   ```
+
+4. For a field path, determines the deepest array path
+
+   ```scala
+   structType.getDeepestArrayPath(path)
+   ```
+
+5. Checks if a field is an array that is not nested in another array
+
+   ```scala
+   structType.isNonNestedArray(path)
+   ```

 # Spark Version Guard
@@ -176,4 +326,15 @@ _DataFrameImplicits_ provides methods for transformations on Dataframes

 ```scala
 df.withColumnIfDoesNotExist((df: DataFrame, _) => df)(colName, colExpression)
+   ```
+
+3. Aligns the schema of a DataFrame to the selector for operations
+   where column order might be important (e.g. hashing the whole rows and using except)
+
+   ```scala
+   df.alignSchema(structType)
+   ```
+
+   ```scala
+   df.alignSchema(listColumns)
 ```
\ No newline at end of file
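A minimal sketch of the alignment described above (`modelDf` and `shuffledDf` are illustrative names):

```scala
import za.co.absa.spark.commons.implicits.DataFrameImplicits.DataFrameEnhancements

// reorders (and positionally filters) shuffledDf's columns to match modelDf's schema
val aligned = shuffledDf.alignSchema(modelDf.schema)
```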
diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/ArrayTypeImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/ArrayTypeImplicits.scala
new file mode 100644
index 00000000..8c552f40
--- /dev/null
+++ b/src/main/scala/za/co/absa/spark/commons/implicits/ArrayTypeImplicits.scala
@@ -0,0 +1,88 @@
+/*
+ * Copyright 2021 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.spark.commons.implicits
+
+import org.apache.spark.sql.types.{ArrayType, DataType, StructType}
+import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements
+
+import scala.annotation.tailrec
+
+object ArrayTypeImplicits {
+
+  implicit class ArrayTypeEnhancements(arrayType: ArrayType) {
+
+    /**
+      * Compares 2 array fields of a dataframe schema.
+      *
+      * @param other The second array to compare
+      * @return true if provided arrays are the same ignoring nullability
+      */
+    @scala.annotation.tailrec
+    final def isEquivalentArrayType(other: ArrayType): Boolean = {
+      arrayType.elementType match {
+        case arrayType1: ArrayType =>
+          other.elementType match {
+            case arrayType2: ArrayType => arrayType1.isEquivalentArrayType(arrayType2)
+            case _ => false
+          }
+        case structType1: StructType =>
+          other.elementType match {
+            case structType2: StructType => structType1.isEquivalent(structType2)
+            case _ => false
+          }
+        case _ => arrayType.elementType == other.elementType
+      }
+    }
+
+    /**
+      * Finds all differences of two ArrayTypes and returns their paths
+      *
+      * @param array2 The second array to compare
+      * @param parent Parent path. This is used for the accumulation of differences and their print out
+      * @return Returns a Seq of found difference paths in the schema of the array
+      */
+    @scala.annotation.tailrec
+    private[implicits] final def diffArray(array2: ArrayType, parent: String): Seq[String] = {
+      arrayType.elementType match {
+        case _ if arrayType.elementType.typeName != array2.elementType.typeName =>
+          Seq(s"$parent data type doesn't match (${arrayType.elementType.typeName}) vs (${array2.elementType.typeName})")
+        case arrayType1: ArrayType =>
+          arrayType1.diffArray(array2.elementType.asInstanceOf[ArrayType], s"$parent")
+        case structType1: StructType =>
+          structType1.diffSchema(array2.elementType.asInstanceOf[StructType], s"$parent")
+        case _ => Seq.empty[String]
+      }
+    }
+
+    /**
+      * For an array of arrays of arrays, ... gets the final element type at the bottom of the array
+      *
+      * @return A non-array data type at the bottom of array nesting
+      */
+    final def getDeepestArrayType(): DataType = {
+      @tailrec
+      def getDeepestArrayTypeHelper(arrayType: ArrayType): DataType = {
+        arrayType.elementType match {
+          case a: ArrayType => getDeepestArrayTypeHelper(a)
+          case b => b
+        }
+      }
+      getDeepestArrayTypeHelper(arrayType)
+    }
+  }
+}
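A quick sketch of the deepest-type lookup above (reflecting the `DataType` return type, which was `Unit` in the draft and is fixed here; the value names are illustrative):

```scala
import org.apache.spark.sql.types.{ArrayType, IntegerType}
import za.co.absa.spark.commons.implicits.ArrayTypeImplicits.ArrayTypeEnhancements

val nested = ArrayType(ArrayType(ArrayType(IntegerType)))
nested.getDeepestArrayType() // IntegerType
```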
diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala
index 4905b1d6..723a2e51 100644
--- a/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala
+++ b/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala
@@ -18,7 +18,9 @@ package za.co.absa.spark.commons.implicits

 import java.io.ByteArrayOutputStream

+import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.{Column, DataFrame}
+import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements

 object DataFrameImplicits {
@@ -73,6 +75,24 @@ object DataFrameImplicits {
         df.withColumn(colName, colExpr)
       }
     }
+
+    /**
+      * Using the schema selector returned from [[StructTypeEnhancements.getDataFrameSelector]], aligns the schema of a DataFrame
+      * to the selector for operations where column order might be important (e.g. hashing the whole rows and using except)
+      *
+      * @param selector a list of columns, typically generated from a model schema
+      * @return the aligned and filtered DataFrame
+      */
+    def alignSchema(selector: List[Column]): DataFrame = df.select(selector: _*)
+
+    /**
+      * Using the schema selector from [[StructTypeEnhancements.getDataFrameSelector]], aligns the schema of a DataFrame to the
+      * selector for operations where column order might be important (e.g. hashing the whole rows and using except)
+      *
+      * @param structType model StructType for the alignment of the DataFrame
+      * @return the aligned and filtered DataFrame
+      */
+    def alignSchema(structType: StructType): DataFrame = alignSchema(structType.getDataFrameSelector())
   }
 }
diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/DataTypeImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/DataTypeImplicits.scala
new file mode 100644
index 00000000..8a8e38e7
--- /dev/null
+++ b/src/main/scala/za/co/absa/spark/commons/implicits/DataTypeImplicits.scala
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2021 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.spark.commons.implicits
+
+import org.apache.spark.sql.types._
+import za.co.absa.spark.commons.implicits.ArrayTypeImplicits.ArrayTypeEnhancements
+import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements
+
+object DataTypeImplicits {
+
+  implicit class DataTypeEnhancements(dt: DataType) {
+    /**
+      * Compares 2 fields of a dataframe schema.
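+      * For example, ArrayType(IntegerType, containsNull = true) is equivalent to ArrayType(IntegerType, containsNull = false).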
+      *
+      * @param other The second field to compare
+      * @return true if provided fields are the same ignoring nullability
+      */
+    def isEquivalentDataType(other: DataType): Boolean = {
+      dt match {
+        case arrayType1: ArrayType =>
+          other match {
+            case arrayType2: ArrayType => arrayType1.isEquivalentArrayType(arrayType2)
+            case _ => false
+          }
+        case structType1: StructType =>
+          other match {
+            case structType2: StructType => structType1.isEquivalent(structType2)
+            case _ => false
+          }
+        case _ => dt == other
+      }
+    }
+
+    /**
+      * Checks if a casting between types always succeeds
+      *
+      * @param targetType A type to be cast to
+      * @return true if casting never fails
+      */
+    def doesCastAlwaysSucceed(targetType: DataType): Boolean = {
+      (dt, targetType) match {
+        case (_: StructType, _) | (_: ArrayType, _) => false
+        case (a, b) if a == b => true
+        case (_, _: StringType) => true
+        case (_: ByteType, _: ShortType | _: IntegerType | _: LongType) => true
+        case (_: ShortType, _: IntegerType | _: LongType) => true
+        case (_: IntegerType, _: LongType) => true
+        case (_: DateType, _: TimestampType) => true
+        case _ => false
+      }
+    }
+
+    /**
+      * Determines if the data type is a primitive one
+      */
+    def isPrimitive(): Boolean = dt match {
+      case _: BinaryType
+           | _: BooleanType
+           | _: ByteType
+           | _: DateType
+           | _: DecimalType
+           | _: DoubleType
+           | _: FloatType
+           | _: IntegerType
+           | _: LongType
+           | _: NullType
+           | _: ShortType
+           | _: StringType
+           | _: TimestampType => true
+      case _ => false
+    }
+  }
+}
diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala
index fdcb601e..24d2fdf5 100644
--- a/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala
+++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructFieldImplicits.scala
@@ -17,6 +17,9 @@
 package za.co.absa.spark.commons.implicits

 import org.apache.spark.sql.types._
+import za.co.absa.spark.commons.implicits.ArrayTypeImplicits.ArrayTypeEnhancements
+import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements
+
 import scala.util.Try

 object StructFieldImplicits {
@@ -44,4 +47,27 @@ object StructFieldImplicits {
       metadata.contains(key)
     }
   }
+
+  implicit class StructFieldEnhancements(val structField: StructField) {
+
+    /**
+      * Finds all differences of two StructFields and returns their paths
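+      * For example, comparing an IntegerType field with a StringType field of the same name reports: data type doesn't match (integer) vs (string).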
+      *
+      * @param other The second field to compare
+      * @param parent Parent path. This is used for the accumulation of differences and their print out
+      * @return Returns a Seq of found difference paths in the schema of the StructField
+      */
+    private[implicits] def diffField(other: StructField, parent: String): Seq[String] = {
+      structField.dataType match {
+        case _ if structField.dataType.typeName != other.dataType.typeName =>
+          Seq(s"$parent.${structField.name} data type doesn't match (${structField.dataType.typeName}) vs (${other.dataType.typeName})")
+        case arrayType1: ArrayType =>
+          arrayType1.diffArray(other.dataType.asInstanceOf[ArrayType], s"$parent.${structField.name}")
+        case structType1: StructType =>
+          structType1.diffSchema(other.dataType.asInstanceOf[StructType], s"$parent.${structField.name}")
+        case _ =>
+          Seq.empty[String]
+      }
+    }
+  }
 }
diff --git a/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala
new file mode 100644
index 00000000..a7289b08
--- /dev/null
+++ b/src/main/scala/za/co/absa/spark/commons/implicits/StructTypeImplicits.scala
@@ -0,0 +1,412 @@
+/*
+ * Copyright 2021 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.spark.commons.implicits
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.functions.{col, struct}
+import org.apache.spark.sql.types._
+import za.co.absa.spark.commons.implicits.DataTypeImplicits.DataTypeEnhancements
+import za.co.absa.spark.commons.implicits.StructFieldImplicits.StructFieldEnhancements
+import za.co.absa.spark.commons.utils.SchemaUtils.{getAllArraySubPaths, isCommonSubPath, transform}
+
+import scala.annotation.tailrec
+import scala.reflect.ClassTag
+import scala.util.Try
+
+object StructTypeImplicits {
+  implicit class StructTypeEnhancements(val schema: StructType) {
+    /**
+      * Get a field from a text path
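+      * For example, getField("a.b") returns the field "b" inside the struct (or array of structs) "a", descending through array element types where needed.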
+      * @param path The dot-separated path to the field
+      * @return Some(the requested field) or None if the field does not exist
+      */
+    def getField(path: String): Option[StructField] = {
+
+      @tailrec
+      def goThroughArrayDataType(dataType: DataType): DataType = {
+        dataType match {
+          case ArrayType(dt, _) => goThroughArrayDataType(dt)
+          case result => result
+        }
+      }
+
+      @tailrec
+      def examineStructField(names: List[String], structField: StructField): Option[StructField] = {
+        if (names.isEmpty) {
+          Option(structField)
+        } else {
+          structField.dataType match {
+            case struct: StructType => examineStructField(names.tail, struct(names.head))
+            case ArrayType(el: DataType, _) =>
+              goThroughArrayDataType(el) match {
+                case struct: StructType => examineStructField(names.tail, struct(names.head))
+                case _ => None
+              }
+            case _ => None
+          }
+        }
+      }
+
+      val pathTokens = path.split('.').toList
+      Try{
+        examineStructField(pathTokens.tail, schema(pathTokens.head))
+      }.getOrElse(None)
+    }
+
+    /**
+      * Get the type of a field from a text path
+      *
+      * @param path The dot-separated path to the field
+      * @return Some(the type of the field) or None if the field does not exist
+      */
+    def getFieldType(path: String): Option[DataType] = {
+      getField(path).map(_.dataType)
+    }
+
+    /**
+      * Checks if the specified path is an array of structs
+      *
+      * @param path The dot-separated path to the field
+      * @return true if the field is an array of structs
+      */
+    def isColumnArrayOfStruct(path: String): Boolean = {
+      getFieldType(path) match {
+        case Some(dt) =>
+          dt match {
+            case arrayType: ArrayType =>
+              arrayType.elementType match {
+                case _: StructType => true
+                case _ => false
+              }
+            case _ => false
+          }
+        case None => false
+      }
+    }
+
+    /**
+      * Get the nullability of a field from a text path
+      *
+      * @param path The dot-separated path to the field
+      * @return Some(nullable) or None if the field does not exist
+      */
+    def getFieldNullability(path: String): Option[Boolean] = {
+      getField(path).map(_.nullable)
+    }
+
+    /**
+      * Checks if a field specified by a path exists
+      * @param path The dot-separated path to the field
+      * @return True if the field exists, false otherwise
+      */
+    def fieldExists(path: String): Boolean = {
+      getField(path).nonEmpty
+    }
+
+    /**
+      * Get paths for all array fields in the schema
+      *
+      * @return Seq of dot-separated paths of fields in the schema, which are of type Array
+      */
+    def getAllArrayPaths(): Seq[String] = {
+      schema.fields.flatMap(f => getAllArraySubPaths("", f.name, f.dataType)).toSeq
+    }
+
+    /**
+      * Returns a data selector that can be used to align the schema of a data frame. You can use [[DataFrameImplicits.DataFrameEnhancements.alignSchema]].
+      *
+      * @return A sorted list of columns that can be used in a select to conform a DataFrame to this schema
+      */
+    def getDataFrameSelector(): List[Column] = {
+
+      def processArray(arrType: ArrayType, column: Column, name: String): Column = {
+        arrType.elementType match {
+          case arrType: ArrayType =>
+            transform(column, x => processArray(arrType, x, name)).as(name)
+          case nestedStructType: StructType =>
+            transform(column, x => struct(processStruct(nestedStructType, Some(x)): _*)).as(name)
+          case _ => column
+        }
+      }
+
+      def processStruct(curSchema: StructType, parent: Option[Column]): List[Column] = {
+        curSchema.foldRight(List.empty[Column])((field, acc) => {
+          val currentCol: Column = parent match {
+            case Some(x) => x.getField(field.name).as(field.name)
+            case None => col(field.name)
+          }
+          field.dataType match {
+            case arrType: ArrayType => processArray(arrType, currentCol, field.name) :: acc
+            case structType: StructType => struct(processStruct(structType, Some(currentCol)): _*).as(field.name) :: acc
+            case _ => currentCol :: acc
+          }
+        })
+      }
+
+      processStruct(schema, None)
+    }
+
+    /**
+      * Get the closest unique column name
+      *
+      * @param desiredName A prefix to use for the column name
+      * @return A name that can be used as a unique column name
+      */
+    def getClosestUniqueName(desiredName: String): String = {
+      def fieldExists(name: String): Boolean = schema.fields.exists(_.name.compareToIgnoreCase(name) == 0)
+
+      if (fieldExists(desiredName)) {
+        Iterator.from(1)
+          .map(index => s"${desiredName}_${index}")
+          .dropWhile(fieldExists).next()
+      } else {
+        desiredName
+      }
+    }
+
+    /**
+      * Checks if a field is the only field in a struct
+      *
+      * @param path A column to check
+      * @return true if the column is the only column in a struct
+      */
+    def isOnlyField(path: String): Boolean = {
+      val pathSegments = path.split('.')
+      evaluateConditionsForField(schema, pathSegments, path, applyArrayHelper = false, applyLeafCondition = true,
+        field => field.fields.length == 1)
+    }
+
+    /**
+      * Compares 2 dataframe schemas.
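+      * Column order, field name case and nullability are all ignored by the comparison.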
+      *
+      * @param other The second schema to compare
+      * @return true if provided schemas are the same ignoring nullability
+      */
+    def isEquivalent(other: StructType): Boolean = {
+      val currentFields = schema.sortBy(_.name.toLowerCase)
+      val fields2 = other.sortBy(_.name.toLowerCase)
+
+      currentFields.size == fields2.size &&
+        currentFields.zip(fields2).forall {
+          case (f1, f2) =>
+            f1.name.equalsIgnoreCase(f2.name) &&
+              f1.dataType.isEquivalentDataType(f2.dataType)
+        }
+    }
+
+    /**
+      * Returns a list of differences in one schema to the other
+      *
+      * @param other The second schema to compare
+      * @param parent Parent path. Should be left at its default by the user on the first call. This is used for the accumulation of
+      *               differences and their print out.
+      * @return Returns a Seq of paths to differences in schemas
+      */
+    def diffSchema(other: StructType, parent: String = ""): Seq[String] = {
+      val fields1 = getMapOfFields(schema)
+      val fields2 = getMapOfFields(other)
+
+      val diff = fields1.values.foldLeft(Seq.empty[String])((difference, field1) => {
+        val field1NameLc = field1.name.toLowerCase()
+        if (fields2.contains(field1NameLc)) {
+          val field2 = fields2(field1NameLc)
+          difference ++ field1.diffField(field2, parent)
+        } else {
+          difference ++ Seq(s"$parent.${field1.name} cannot be found in both schemas")
+        }
+      })
+
+      diff.map(_.stripPrefix("."))
+    }
+
+    /**
+      * Checks if this schema is a subset of the originalSchema.
+      *
+      * @param originalSchema The schema that needs to contain all the fields of this schema
+      * @return true if this schema is a subset of originalSchema, ignoring nullability
+      */
+    def isSubset(originalSchema: StructType): Boolean = {
+      val subsetFields = getMapOfFields(schema)
+      val originalFields = getMapOfFields(originalSchema)
+
+      subsetFields.forall(subsetField =>
+        originalFields.contains(subsetField._1) &&
+          subsetField._2.dataType.isEquivalentDataType(originalFields(subsetField._1).dataType))
+    }
+
+    private def getMapOfFields(schema: StructType): Map[String, StructField] = {
+      schema.map(field => field.name.toLowerCase() -> field).toMap
+    }
+
+    def isOfType[T <: DataType](path: String)(implicit ev: ClassTag[T]): Boolean = {
+      val fieldType = getFieldType(path).getOrElse(NullType)
+
+      fieldType match {
+        case _: T => true
+        case _ => false
+      }
+    }
+
+    protected def evaluateConditionsForField(structField: StructType, path: Seq[String], fieldPathName: String,
+                                             applyArrayHelper: Boolean, applyLeafCondition: Boolean = false,
+                                             conditionLeafSh: StructType => Boolean = _ => false): Boolean = {
+      val currentField = path.head
+      val isLeaf = path.lengthCompare(1) <= 0
+
+      @tailrec
+      def arrayHelper(fieldPathName: String, arrayField: ArrayType, path: Seq[String]): Boolean = {
+        arrayField.elementType match {
+          case st: StructType =>
+            evaluateConditionsForField(st, path.tail, fieldPathName, applyArrayHelper, applyLeafCondition, conditionLeafSh)
+          case ar: ArrayType => arrayHelper(fieldPathName, ar, path)
+          case _ =>
+            if (!isLeaf) {
+              throw new IllegalArgumentException(
+                s"Primitive fields cannot have child fields $currentField is a primitive in $fieldPathName")
+            }
+            false
+        }
+      }
+
+      structField.fields.exists(field =>
+        if (field.name == currentField) {
+          (field.dataType, isLeaf) match {
+            case (st: StructType, false) =>
+              evaluateConditionsForField(st, path.tail, fieldPathName, applyArrayHelper, applyLeafCondition, conditionLeafSh)
+            case (_, true) if applyLeafCondition => conditionLeafSh(structField)
+            case (_: ArrayType, true) => true
+            case (ar: ArrayType, false) if applyArrayHelper => arrayHelper(fieldPathName,
ar, path) + case (_: ArrayType, false) => false + case (_, false) => throw new IllegalArgumentException(s"Primitive fields cannot have child fields $currentField is a primitive in $fieldPathName") + case (_, _) => false + } + } else false + ) + } + } + + implicit class StructTypeEnhancementsArrays(schema: StructType) extends StructTypeEnhancements(schema) { + /** + * Get first array column's path out of complete path. + * + * E.g if the path argument is "a.b.c.d.e" where b and d are arrays, "a.b" will be returned. + * + * @param path The path to the attribute + * @return The path of the first array field or "" if none were found + */ + def getFirstArrayPath(path: String): String = { + @tailrec + def helper(remPath: Seq[String], pathAcc: Seq[String]): Seq[String] = { + if (remPath.isEmpty) Seq() else { + val currPath = (pathAcc :+ remPath.head).mkString(".") + val currType = getFieldType(currPath) + currType match { + case Some(_: ArrayType) => pathAcc :+ remPath.head + case Some(_) => helper(remPath.tail, pathAcc :+ remPath.head) + case None => Seq() + } + } + } + + val pathToks = path.split('.') + helper(pathToks, Seq()).mkString(".") + } + + /** + * Get all array columns' paths out of complete path. + * + * E.g. if the path argument is "a.b.c.d.e" where b and d are arrays, "a.b" and "a.b.c.d" will be returned. + * + * @param path The path to the attribute + * @return Seq of dot-separated paths for all array fields in the provided path + */ + def getAllArraysInPath(path: String): Seq[String] = { + @tailrec + def helper(remPath: Seq[String], pathAcc: Seq[String], arrayAcc: Seq[String]): Seq[String] = { + if (remPath.isEmpty) arrayAcc else { + val currPath = (pathAcc :+ remPath.head).mkString(".") + val currType = getFieldType(currPath) + currType match { + case Some(_: ArrayType) => + val strings = pathAcc :+ remPath.head + helper(remPath.tail, strings, arrayAcc :+ strings.mkString(".")) + case Some(_) => helper(remPath.tail, pathAcc :+ remPath.head, arrayAcc) + case None => arrayAcc + } + } + } + + val pathToks = path.split("\\.") + helper(pathToks, Seq(), Seq()) + } + + /** + * For a given list of field paths determines the deepest common array path. + * + * For instance, if given 'a.b', 'a.b.c', 'a.b.c.d' where b and c are arrays the common deepest array + * path is 'a.b.c'. + * + * If any of the arrays are on diverging paths this function returns None. + * + * The purpose of the function is to determine the order of explosions to be made before the dataframe can be + * joined on a field inside an array. + * + * @param paths A list of paths to analyze + * @return Returns a common array path if there is one and None if any of the arrays are on diverging paths + */ + def getDeepestCommonArrayPath(paths: Seq[String]): Option[String] = { + val arrayPaths = paths.flatMap(path => getAllArraysInPath(path)).distinct + + if (arrayPaths.nonEmpty && isCommonSubPath(arrayPaths: _*)) { + Some(arrayPaths.maxBy(_.length)) + } else { + None + } + } + + /** + * For a field path determines the deepest array path. + * + * For instance, if given 'a.b.c.d' where b and c are arrays the deepest array is 'a.b.c'. 
+      *
+      * @param path A path to analyze
+      * @return Returns the deepest array path if there is one, or None if the path contains no arrays
+      */
+    def getDeepestArrayPath(path: String): Option[String] = {
+      val arrayPaths = getAllArraysInPath(path)
+
+      if (arrayPaths.nonEmpty) {
+        Some(arrayPaths.maxBy(_.length))
+      } else {
+        None
+      }
+    }
+
+    /**
+      * Checks if a field is an array that is not nested in another array
+      *
+      * @param path A field to check
+      * @return true if a field is an array that is not nested in another array
+      */
+    def isNonNestedArray(path: String): Boolean = {
+      val pathSegments = path.split('.')
+      evaluateConditionsForField(schema, pathSegments, path, applyArrayHelper = false)
+    }
+  }
+
+}
diff --git a/src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala b/src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala
deleted file mode 100644
index 9de12d61..00000000
--- a/src/main/scala/za/co/absa/spark/commons/schema/SchemaUtils.scala
+++ /dev/null
@@ -1,235 +0,0 @@
-/*
- * Copyright 2021 ABSA Group Limited
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package za.co.absa.spark.commons.schema
-
-import org.apache.spark.sql.functions.{col, struct}
-import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType}
-import org.apache.spark.sql.{Column, DataFrame}
-import za.co.absa.spark.commons.adapters.HofsAdapter
-
-import scala.util.Random
-
-object SchemaUtils extends HofsAdapter {
-  /**
-    * Compares 2 array fields of a dataframe schema.
-    *
-    * @param array1 The first array to compare
-    * @param array2 The second array to compare
-    * @return true if provided arrays are the same ignoring nullability
-    */
-  @scala.annotation.tailrec
-  private def equalArrayTypes(array1: ArrayType, array2: ArrayType): Boolean = {
-    array1.elementType match {
-      case arrayType1: ArrayType =>
-        array2.elementType match {
-          case arrayType2: ArrayType => equalArrayTypes(arrayType1, arrayType2)
-          case _ => false
-        }
-      case structType1: StructType =>
-        array2.elementType match {
-          case structType2: StructType => equivalentSchemas(structType1, structType2)
-          case _ => false
-        }
-      case _ => array1.elementType == array2.elementType
-    }
-  }
-
-  /**
-    * Finds all differences of two ArrayTypes and returns their paths
-    *
-    * @param array1 The first array to compare
-    * @param array2 The second array to compare
-    * @param parent Parent path.
This is used for the accumulation of differences and their print out - * @return Returns a Seq of found difference paths in scheme in the Array - */ - @scala.annotation.tailrec - private def diffArray(array1: ArrayType, array2: ArrayType, parent: String): Seq[String] = { - array1.elementType match { - case _ if array1.elementType.typeName != array2.elementType.typeName => - Seq(s"$parent data type doesn't match (${array1.elementType.typeName}) vs (${array2.elementType.typeName})") - case arrayType1: ArrayType => - diffArray(arrayType1, array2.elementType.asInstanceOf[ArrayType], s"$parent") - case structType1: StructType => - diffSchema(structType1, array2.elementType.asInstanceOf[StructType], s"$parent") - case _ => Seq.empty[String] - } - } - - /** - * Compares 2 fields of a dataframe schema. - * - * @param type1 The first field to compare - * @param type2 The second field to compare - * @return true if provided fields are the same ignoring nullability - */ - private def equivalentTypes(type1: DataType, type2: DataType): Boolean = { - type1 match { - case arrayType1: ArrayType => - type2 match { - case arrayType2: ArrayType => equalArrayTypes(arrayType1, arrayType2) - case _ => false - } - case structType1: StructType => - type2 match { - case structType2: StructType => equivalentSchemas(structType1, structType2) - case _ => false - } - case _ => type1 == type2 - } - } - - /** - * Finds all differences of two StructFields and returns their paths - * - * @param field1 The first field to compare - * @param field2 The second field to compare - * @param parent Parent path. This is used for the accumulation of differences and their print out - * @return Returns a Seq of found difference paths in scheme in the StructField - */ - private def diffField(field1: StructField, field2: StructField, parent: String): Seq[String] = { - field1.dataType match { - case _ if field1.dataType.typeName != field2.dataType.typeName => - Seq(s"$parent.${field1.name} data type doesn't match (${field1.dataType.typeName}) vs (${field2.dataType.typeName})") - case arrayType1: ArrayType => - diffArray(arrayType1, field2.dataType.asInstanceOf[ArrayType], s"$parent.${field1.name}") - case structType1: StructType => - diffSchema(structType1, field2.dataType.asInstanceOf[StructType], s"$parent.${field1.name}") - case _ => - Seq.empty[String] - } - } - - /** - * Returns data selector that can be used to align schema of a data frame. You can use [[alignSchema]]. 
- * - * @param schema Schema that serves as the model of column order - * @return Sorted DF to conform to schema - */ - def getDataFrameSelector(schema: StructType): List[Column] = { - - def processArray(arrType: ArrayType, column: Column, name: String): Column = { - arrType.elementType match { - case arrType: ArrayType => - transform(column, x => processArray(arrType, x, name)).as(name) - case nestedStructType: StructType => - transform(column, x => struct(processStruct(nestedStructType, Some(x)): _*)).as(name) - case _ => column - } - } - - def processStruct(curSchema: StructType, parent: Option[Column]): List[Column] = { - curSchema.foldRight(List.empty[Column])((field, acc) => { - val currentCol: Column = parent match { - case Some(x) => x.getField(field.name).as(field.name) - case None => col(field.name) - } - field.dataType match { - case arrType: ArrayType => processArray(arrType, currentCol, field.name) :: acc - case structType: StructType => struct(processStruct(structType, Some(currentCol)): _*).as(field.name) :: acc - case _ => currentCol :: acc - } - }) - } - - processStruct(schema, None) - } - - /** - * Using schema selector from [[getDataFrameSelector]] aligns the schema of a DataFrame to the selector for operations - * where schema order might be important (e.g. hashing the whole rows and using except) - * - * @param df DataFrame to have it's schema aligned/sorted - * @param structType model structType for the alignment of df - * @return Returns aligned and filtered schema - */ - def alignSchema(df: DataFrame, structType: StructType): DataFrame = df.select(getDataFrameSelector(structType): _*) - - /** - * Using schema selector returned from [[getDataFrameSelector]] aligns the schema of a DataFrame to the selector - * for operations where schema order might be important (e.g. hashing the whole rows and using except) - * - * @param df DataFrame to have it's schema aligned/sorted - * @param selector model structType for the alignment of df - * @return Returns aligned and filtered schema - */ - def alignSchema(df: DataFrame, selector: List[Column]): DataFrame = df.select(selector: _*) - - /** - * Compares 2 dataframe schemas. - * - * @param schema1 The first schema to compare - * @param schema2 The second schema to compare - * @return true if provided schemas are the same ignoring nullability - */ - def equivalentSchemas(schema1: StructType, schema2: StructType): Boolean = { - val fields1 = schema1.sortBy(_.name.toLowerCase) - val fields2 = schema2.sortBy(_.name.toLowerCase) - - fields1.size == fields2.size && - fields1.zip(fields2).forall { - case (f1, f2) => - f1.name.equalsIgnoreCase(f2.name) && - equivalentTypes(f1.dataType, f2.dataType) - } - } - - /** - * Returns a list of differences in one schema to the other - * - * @param schema1 The first schema to compare - * @param schema2 The second schema to compare - * @param parent Parent path. Should be left default by the users first run. This is used for the accumulation of - * differences and their print out. 
- * @return Returns a Seq of paths to differences in schemas - */ - def diffSchema(schema1: StructType, schema2: StructType, parent: String = ""): Seq[String] = { - val fields1 = getMapOfFields(schema1) - val fields2 = getMapOfFields(schema2) - - val diff = fields1.values.foldLeft(Seq.empty[String])((difference, field1) => { - val field1NameLc = field1.name.toLowerCase() - if (fields2.contains(field1NameLc)) { - val field2 = fields2(field1NameLc) - difference ++ diffField(field1, field2, parent) - } else { - difference ++ Seq(s"$parent.${field1.name} cannot be found in both schemas") - } - }) - - diff.map(_.stripPrefix(".")) - } - - /** - * Checks if the originalSchema is a subset of subsetSchema. - * - * @param subsetSchema The schema that needs to be extracted - * @param originalSchema The schema that needs to have at least all t - * @return true if provided schemas are the same ignoring nullability - */ - def isSubset(subsetSchema: StructType, originalSchema: StructType): Boolean = { - val subsetFields = getMapOfFields(subsetSchema) - val originalFields = getMapOfFields(originalSchema) - - subsetFields.forall(subsetField => - originalFields.contains(subsetField._1) && - equivalentTypes(subsetField._2.dataType, originalFields(subsetField._1).dataType)) - } - - private def getMapOfFields(schema: StructType): Map[String, StructField] = { - schema.map(field => field.name.toLowerCase() -> field).toMap - } -} diff --git a/src/main/scala/za/co/absa/spark/commons/utils/SchemaUtils.scala b/src/main/scala/za/co/absa/spark/commons/utils/SchemaUtils.scala new file mode 100644 index 00000000..fcf1e8d6 --- /dev/null +++ b/src/main/scala/za/co/absa/spark/commons/utils/SchemaUtils.scala @@ -0,0 +1,98 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.spark.commons.utils + +import org.apache.spark.sql.types._ +import za.co.absa.spark.commons.adapters.HofsAdapter + + +object SchemaUtils extends HofsAdapter { + + /** + * Returns the parent path of a field. Returns an empty string if a root level field name is provided. 
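+    * For example, getParentPath("a.b.c") returns "a.b", while getParentPath("a") returns "".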
+    *
+    * @param columnName A fully qualified column name
+    * @return The parent column name or an empty string if the input column is a root level column
+    */
+  def getParentPath(columnName: String): String = {
+    val index = columnName.lastIndexOf('.')
+    if (index > 0) {
+      columnName.substring(0, index)
+    } else {
+      ""
+    }
+  }
+
+  /**
+    * Get paths for all array subfields of the given data type
+    *
+    * @param path The path to the attribute
+    * @param name The name of the attribute
+    * @param dt   Data type to be checked
+    * @return Seq of dot-separated paths to all array subfields of the given data type
+    */
+  def getAllArraySubPaths(path: String, name: String, dt: DataType): Seq[String] = {
+    val currPath = appendPath(path, name)
+    dt match {
+      case s: StructType => s.fields.flatMap(f => getAllArraySubPaths(currPath, f.name, f.dataType))
+      case _@ArrayType(elType, _) => getAllArraySubPaths(path, name, elType) :+ currPath
+      case _ => Seq()
+    }
+  }
+
+  /**
+    * For a given list of field paths determines if any path pair is a subset of one another.
+    *
+    * For instance,
+    *  - 'a.b', 'a.b.c', 'a.b.c.d' have this property.
+    *  - 'a.b', 'a.b.c', 'a.x.y' does NOT have it, since 'a.b.c' and 'a.x.y' have diverging subpaths.
+    *
+    * @param paths A list of paths to be analyzed
+    * @return true if for all paths the above property holds
+    */
+  def isCommonSubPath(paths: String*): Boolean = {
+    def sliceRoot(paths: Seq[Seq[String]]): Seq[Seq[String]] = {
+      paths.map(path => path.drop(1)).filter(_.nonEmpty)
+    }
+
+    var isParentCommon = true // for Seq() the property holds by convention
+    var restOfPaths: Seq[Seq[String]] = paths.map(_.split('.').toSeq).filter(_.nonEmpty)
+    while (isParentCommon && restOfPaths.nonEmpty) {
+      val parent = restOfPaths.head.head
+      isParentCommon = restOfPaths.forall(path => path.head == parent)
+      restOfPaths = sliceRoot(restOfPaths)
+    }
+    isParentCommon
+  }
+
+  /**
+    * Appends a field name to a path, handling empty values.
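+    * For example, appendPath("a.b", "c") returns "a.b.c", appendPath("", "c") returns "c", and appendPath("a.b", "") returns "a.b".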
+ * + * @param path The dot-separated existing path + * @param fieldName Name of the field to be appended to the path + * @return The path with the new field appended or the field itself if path is empty + */ + def appendPath(path: String, fieldName: String): String = { + if (path.isEmpty) { + fieldName + } else if (fieldName.isEmpty) { + path + } else { + s"$path.$fieldName" + } + } + +} diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala index 20d1e62b..f075eafa 100644 --- a/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala +++ b/src/test/scala/za/co/absa/spark/commons/implicits/ColumnImplicitsTest.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.functions.lit import org.scalatest.funsuite.AnyFunSuite import za.co.absa.spark.commons.implicits.ColumnImplicits.ColumnEnhancements -class ColumnImplicitsTest extends AnyFunSuite{ +class ColumnImplicitsTest extends AnyFunSuite { private val column: Column = lit("abcdefgh") diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala b/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsTest.scala similarity index 89% rename from src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala rename to src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsTest.scala index 801751ea..3de27353 100644 --- a/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsSuite.scala +++ b/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsTest.scala @@ -16,13 +16,12 @@ package za.co.absa.spark.commons.implicits -import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.{AnalysisException, DataFrame} import org.scalatest.funsuite.AnyFunSuite import za.co.absa.spark.commons.test.SparkTestBase -class DataFrameImplicitsSuite extends AnyFunSuite with SparkTestBase { - +class DataFrameImplicitsTest extends AnyFunSuite with SparkTestBase with JsonTestData { import spark.implicits._ private val columnName = "data" @@ -55,6 +54,7 @@ class DataFrameImplicitsSuite extends AnyFunSuite with SparkTestBase { "y", "z" ) + private val inputData = inputDataSeq.toDF(columnName) import za.co.absa.spark.commons.implicits.DataFrameImplicits.DataFrameEnhancements @@ -235,4 +235,23 @@ class DataFrameImplicitsSuite extends AnyFunSuite with SparkTestBase { assert(dfIn.schema.head.name == "value") assert(actualOutput == expectedOutput) } + + test("order schemas for equal schemas") { + val dfA = spark.read.json(Seq(jsonA).toDS) + val dfC = spark.read.json(Seq(jsonC).toDS).select("legs", "id", "key") + + val dfA2Aligned = dfC.alignSchema(dfA.schema) + + assert(dfA.columns.toSeq == dfA2Aligned.columns.toSeq) + assert(dfA.select("key").columns.toSeq == dfA2Aligned.select("key").columns.toSeq) + } + + test("throw an error for DataFrames with different schemas") { + val dfA = spark.read.json(Seq(jsonA).toDS) + val dfB = spark.read.json(Seq(jsonB).toDS) + + assertThrows[AnalysisException]{ + dfA.alignSchema(dfB.schema) + } + } } diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/DataTypeImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/DataTypeImplicitsTest.scala new file mode 100644 index 00000000..1049667e --- /dev/null +++ b/src/test/scala/za/co/absa/spark/commons/implicits/DataTypeImplicitsTest.scala @@ -0,0 +1,68 @@ +/* + * Copyright 2021 ABSA Group Limited 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.spark.commons.implicits + +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.spark.commons.implicits.DataTypeImplicits.DataTypeEnhancements + +class DataTypeImplicitsTest extends AnyFunSuite with JsonTestData { + + test ("Test doesCastAlwaysSucceed()") { + assert(!StructType(Seq()).doesCastAlwaysSucceed(StringType)) + assert(!ArrayType(StringType).doesCastAlwaysSucceed(StringType)) + assert(!StringType.doesCastAlwaysSucceed(ByteType)) + assert(!StringType.doesCastAlwaysSucceed(ShortType)) + assert(!StringType.doesCastAlwaysSucceed(IntegerType)) + assert(!StringType.doesCastAlwaysSucceed(LongType)) + assert(!StringType.doesCastAlwaysSucceed(DecimalType(10,10))) + assert(!StringType.doesCastAlwaysSucceed(DateType)) + assert(!StringType.doesCastAlwaysSucceed(TimestampType)) + assert(!StructType(Seq()).doesCastAlwaysSucceed(StructType(Seq()))) + assert(!ArrayType(StringType).doesCastAlwaysSucceed(ArrayType(StringType))) + + assert(!ShortType.doesCastAlwaysSucceed(ByteType)) + assert(!IntegerType.doesCastAlwaysSucceed(ByteType)) + assert(!IntegerType.doesCastAlwaysSucceed(ShortType)) + assert(!LongType.doesCastAlwaysSucceed(ByteType)) + assert(!LongType.doesCastAlwaysSucceed(ShortType)) + assert(!LongType.doesCastAlwaysSucceed(IntegerType)) + + assert(StringType.doesCastAlwaysSucceed(StringType)) + assert(ByteType.doesCastAlwaysSucceed(StringType)) + assert(ShortType.doesCastAlwaysSucceed(StringType)) + assert(IntegerType.doesCastAlwaysSucceed(StringType)) + assert(LongType.doesCastAlwaysSucceed(StringType)) + assert(DecimalType(10,10).doesCastAlwaysSucceed(StringType)) + assert(DateType.doesCastAlwaysSucceed(StringType)) + assert(TimestampType.doesCastAlwaysSucceed(StringType)) + assert(StringType.doesCastAlwaysSucceed(StringType)) + + assert(ByteType.doesCastAlwaysSucceed(ByteType)) + assert(ByteType.doesCastAlwaysSucceed(ShortType)) + assert(ByteType.doesCastAlwaysSucceed(IntegerType)) + assert(ByteType.doesCastAlwaysSucceed(LongType)) + assert(ShortType.doesCastAlwaysSucceed(ShortType)) + assert(ShortType.doesCastAlwaysSucceed(IntegerType)) + assert(ShortType.doesCastAlwaysSucceed(LongType)) + assert(IntegerType.doesCastAlwaysSucceed(IntegerType)) + assert(IntegerType.doesCastAlwaysSucceed(LongType)) + assert(LongType.doesCastAlwaysSucceed(LongType)) + assert(DateType.doesCastAlwaysSucceed(TimestampType)) + } + +} diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/JsonTestData.scala b/src/test/scala/za/co/absa/spark/commons/implicits/JsonTestData.scala new file mode 100644 index 00000000..b7c21c34 --- /dev/null +++ b/src/test/scala/za/co/absa/spark/commons/implicits/JsonTestData.scala @@ -0,0 +1,66 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.spark.commons.implicits + +import org.apache.spark.sql.types._ + +trait JsonTestData { + + protected val jsonA = """[{"id":1,"legs":[{"legid":100,"conditions":[{"checks":[{"checkNums":["1","2","3b","4","5c","6"]}],"amount":100}]}], "key" : {"alfa": "1", "beta": {"beta2": "2"}} }]""" + protected val jsonB = """[{"id":1,"legs":[{"legid":100,"conditions":[{"checks":[{"checkNums":["1","2","3b","4","5c","6"]}],"amount":100,"price":10}]}]}]""" + protected val jsonC = """[{"legs":[{"legid":100,"conditions":[{"amount":100,"checks":[{"checkNums":["1","2","3b","4","5c","6"]}]}]}],"id":1, "key" : {"beta": {"beta2": "2"}, "alfa": "1"} }]""" + protected val jsonD = """[{"legs":[{"legid":100,"conditions":[{"amount":100,"checks":[{"checkNums":["1","2","3b","4","5c","6"]}]}]}],"id":1, "key" : {"beta": {"beta2": 2}, "alfa": 1} }]""" + protected val jsonE = """[{"legs":[{"legid":100,"conditions":[{"amount":100,"checks":[{"checkNums":["1","2","3b","4","5c","6"]}]}]}],"id":1, "key" : {"beta": {"beta2": 2}, "alfa": 1}, "extra" : "a"}]""" + + protected val sample = + """{"id":1,"legs":[{"legid":100,"conditions":[{"checks":[{"checkNums":["1","2","3b","4","5c","6"]}],"amount":100}]}]}""" :: + """{"id":2,"legs":[{"legid":200,"conditions":[{"checks":[{"checkNums":["8","9","10b","11","12c","13"]}],"amount":200}]}]}""" :: + """{"id":3,"legs":[{"legid":300,"conditions":[{"checks":[],"amount": 300}]}]}""" :: + """{"id":4,"legs":[{"legid":400,"conditions":[{"checks":null,"amount": 400}]}]}""" :: + """{"id":5,"legs":[{"legid":500,"conditions":[]}]}""" :: + """{"id":6,"legs":[]}""" :: + """{"id":7}""" :: Nil + + protected val sampleA: Seq[String] = sample :+ """{"id":1,"legs":[{"legid":100,"conditions":[{"checks":[{"checkNums":["1","2","3b","4","5c","6"]}],"amount":100}]}], "key" : {"alfa": "1", "beta": {"beta2": "2"}}}""" + protected val sampleE: Seq[String] = sample :+ """{"legs":[{"legid":100,"conditions":[{"amount":100,"checks":[{"checkNums":["1","2","3b","4","5c","6"]}]}]}],"id":1, "key" : {"beta": {"beta2": 2}, "alfa": 1}, "extra" : "a"}""" + + protected val schema = StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", StructType(Seq( + StructField("c", IntegerType), + StructField("d", StructType(Seq( + StructField("e", IntegerType))), nullable = true)))), + StructField("f", StructType(Seq( + StructField("g", ArrayType.apply(StructType(Seq( + StructField("h", IntegerType)))))))))) + + protected val nestedSchema = StructType(Seq( + StructField("a", IntegerType), + StructField("b", ArrayType(StructType(Seq( + StructField("c", StructType(Seq( + StructField("d", ArrayType(StructType(Seq( + StructField("e", IntegerType)))))))))))))) + + protected val arrayOfArraysSchema = StructType(Seq( + StructField("a", ArrayType(ArrayType(IntegerType)), nullable = false), + StructField("b", ArrayType(ArrayType(StructType(Seq( + StructField("c", StringType, nullable = false) + )) + )), nullable = true) + )) + +} diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala 
b/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala index 055afa9c..d39a28d0 100644 --- a/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala +++ b/src/test/scala/za/co/absa/spark/commons/implicits/StructFieldImplicitsTest.scala @@ -58,5 +58,4 @@ class StructFieldImplicitsTest extends AnyFunSuite { assertResult(true)(fieldWith("\"hvh\"").metadata.hasKey("a")) assertResult(true)(fieldWith("null").metadata.hasKey("a")) } - } diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsArrayTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsArrayTest.scala new file mode 100644 index 00000000..620dd8c1 --- /dev/null +++ b/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsArrayTest.scala @@ -0,0 +1,197 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.spark.commons.implicits + +import org.apache.spark.sql.types.{ArrayType, DataType, LongType, NumericType, StringType, StructField, StructType} +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancementsArrays +import za.co.absa.spark.commons.test.SparkTestBase + +class StructTypeImplicitsArrayTest extends AnyFunSuite with SparkTestBase with JsonTestData { + + import spark.implicits._ + val df = spark.read.json(sample.toDS) + + val dfA = spark.read.json(sampleA.toDS) + val dfE = spark.read.json(sampleE.toDS) + + test("Testing getFirstArrayPath") { + assertResult("f.g")(schema.getFirstArrayPath("f.g.h")) + assertResult("f.g")(schema.getFirstArrayPath("f.g")) + assertResult("")(schema.getFirstArrayPath("z.x.y")) + assertResult("")(schema.getFirstArrayPath("b.c.d.e")) + } + + test("Testing getAllArraysInPath") { + assertResult(Seq("b", "b.c.d"))(nestedSchema.getAllArraysInPath("b.c.d.e")) + } + + test("Test isNonNestedArray") { + assert(df.schema.isNonNestedArray("legs")) + assert(!df.schema.isNonNestedArray("legs.conditions")) + assert(!df.schema.isNonNestedArray("legs.conditions.checks")) + assert(!df.schema.isNonNestedArray("legs.conditions.checks.checkNums")) + assert(!df.schema.isNonNestedArray("id")) + assert(!df.schema.isNonNestedArray("legs.legid")) + } + + test("Test isOfType ArrayType") { + assert(df.schema.isOfType[ArrayType]("legs")) + assert(df.schema.isOfType[ArrayType]("legs.conditions")) + assert(df.schema.isOfType[ArrayType]("legs.conditions.checks")) + assert(df.schema.isOfType[ArrayType]("legs.conditions.checks.checkNums")) + + assert(!df.schema.isOfType[ArrayType]("id")) + assert(!df.schema.isOfType[ArrayType]("legs.legid")) + } + + test("Test isOfType LongType") { + assert(dfA.schema.isOfType[LongType]("legs.legid")) + assert(dfA.schema.isOfType[LongType]("id")) + assert(!dfA.schema.isOfType[StringType]("id")) + assert(dfA.schema.isOfType[NumericType]("id")) + assert(dfA.schema.isOfType[DataType]("id")) + } + + test("Test isOfType StructType") { + 
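    // dfE's "id" is inferred from JSON as a numeric (LongType) column, so it must not qualify as a StructType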
assert(!dfE.schema.isOfType[StructType]("id")) + assert(dfE.schema.isOfType[StructType]("key")) + assert(!dfE.schema.isOfType[StructType]("key.alfa")) + assert(dfE.schema.isOfType[StructType]("key.beta")) + assert(!dfE.schema.isOfType[StructType]("key.beta.beta2")) + assert(!dfE.schema.isOfType[StructType]("extra")) + } + + test("Test isOfType StringType") { + assert(!dfA.schema.isOfType[StringType]("id")) + assert(!dfA.schema.isOfType[StringType]("key")) + assert(dfA.schema.isOfType[StringType]("key.alfa")) + assert(!dfA.schema.isOfType[StringType]("key.beta")) + assert(dfA.schema.isOfType[StringType]("key.beta.beta2")) + assert(!dfA.schema.isOfType[StringType]("extra")) + + assert(!dfE.schema.isOfType[StringType]("key.alfa")) + assert(dfE.schema.isOfType[StringType]("extra")) + } + + test("Test getDeepestCommonArrayPath() for a path without an array") { + val schema = StructType(Seq[StructField]( + StructField("a", + StructType(Seq[StructField]( + StructField("b", StringType)) + )))) + + assert(schema.getDeepestCommonArrayPath(Seq("a", "a.b")).isEmpty) + } + + test("Test getDeepestCommonArrayPath() for a path with a single array at top level") { + val schema = StructType(Seq[StructField]( + StructField("a", ArrayType(StructType(Seq[StructField]( + StructField("b", StringType))) + )))) + + val deepestPath = schema.getDeepestCommonArrayPath(Seq("a", "a.b")) + + assert(deepestPath.nonEmpty) + assert(deepestPath.get == "a") + } + + test("Test getDeepestCommonArrayPath() for a path with a single array at nested level") { + val schema = StructType(Seq[StructField]( + StructField("a", StructType(Seq[StructField]( + StructField("b", ArrayType(StringType)))) + ))) + + val deepestPath = schema.getDeepestCommonArrayPath(Seq("a", "a.b")) + + assert(deepestPath.nonEmpty) + assert(deepestPath.get == "a.b") + } + + test("Test getDeepestCommonArrayPath() for a path with several nested arrays of struct") { + val schema = StructType(Seq[StructField]( + StructField("a", ArrayType(StructType(Seq[StructField]( + StructField("b", StructType(Seq[StructField]( + StructField("c", ArrayType(StructType(Seq[StructField]( + StructField("d", StructType(Seq[StructField]( + StructField("e", StringType)) + ))) + )))) + ))) + ))))) + + val deepestPath = schema.getDeepestCommonArrayPath(Seq("a", "a.b", "a.b.c.d.e", "a.b.c.d")) + + assert(deepestPath.nonEmpty) + assert(deepestPath.get == "a.b.c") + } + + test("Test getDeepestArrayPath() for a path without an array") { + val schema = StructType(Seq[StructField]( + StructField("a", + StructType(Seq[StructField]( + StructField("b", StringType)) + )))) + + assert(schema.getDeepestArrayPath("a.b").isEmpty) + } + + test("Test getDeepestArrayPath() for a path with a single array at top level") { + val schema = StructType(Seq[StructField]( + StructField("a", ArrayType(StructType(Seq[StructField]( + StructField("b", StringType))) + )))) + + val deepestPath = schema.getDeepestArrayPath("a.b") + + assert(deepestPath.nonEmpty) + assert(deepestPath.get == "a") + } + + test("Test getDeepestArrayPath() for a path with a single array at nested level") { + val schema = StructType(Seq[StructField]( + StructField("a", StructType(Seq[StructField]( + StructField("b", ArrayType(StringType)))) + ))) + + val deepestPath = schema.getDeepestArrayPath("a.b") + val deepestPath2 = schema.getDeepestArrayPath("a") + + assert(deepestPath.nonEmpty) + assert(deepestPath.get == "a.b") + assert(deepestPath2.isEmpty) + } + + test("Test getDeepestArrayPath() for a path with several nested arrays of struct") { 
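    // arrays occur at 'a' and 'a.b.c', so the deepest array path for 'a.b.c.d.e' is 'a.b.c'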
+ val schema = StructType(Seq[StructField]( + StructField("a", ArrayType(StructType(Seq[StructField]( + StructField("b", StructType(Seq[StructField]( + StructField("c", ArrayType(StructType(Seq[StructField]( + StructField("d", StructType(Seq[StructField]( + StructField("e", StringType)) + ))) + )))) + ))) + ))))) + + val deepestPath = schema.getDeepestArrayPath("a.b.c.d.e") + + assert(deepestPath.nonEmpty) + assert(deepestPath.get == "a.b.c") + } + +} diff --git a/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsTest.scala b/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsTest.scala new file mode 100644 index 00000000..017b3bf3 --- /dev/null +++ b/src/test/scala/za/co/absa/spark/commons/implicits/StructTypeImplicitsTest.scala @@ -0,0 +1,219 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.spark.commons.implicits + +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite +import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements +import za.co.absa.spark.commons.test.SparkTestBase + +class StructTypeImplicitsTest extends AnyFunSuite with SparkTestBase with JsonTestData { + // scalastyle:off magic.number + + test("Testing getFieldType") { + import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements + + val a = schema.getFieldType("a") + val b = schema.getFieldType("b") + val c = schema.getFieldType("b.c") + val d = schema.getFieldType("b.d") + val e = schema.getFieldType("b.d.e") + val f = schema.getFieldType("f") + val g = schema.getFieldType("f.g") + val h = schema.getFieldType("f.g.h") + + assert(a.get.isInstanceOf[IntegerType]) + assert(b.get.isInstanceOf[StructType]) + assert(c.get.isInstanceOf[IntegerType]) + assert(d.get.isInstanceOf[StructType]) + assert(e.get.isInstanceOf[IntegerType]) + assert(f.get.isInstanceOf[StructType]) + assert(g.get.isInstanceOf[ArrayType]) + assert(h.get.isInstanceOf[IntegerType]) + assert(schema.getFieldType("z").isEmpty) + assert(schema.getFieldType("x.y.z").isEmpty) + assert(schema.getFieldType("f.g.h.a").isEmpty) + } + + test("Testing fieldExists") { + assert(schema.fieldExists("a")) + assert(schema.fieldExists("b")) + assert(schema.fieldExists("b.c")) + assert(schema.fieldExists("b.d")) + assert(schema.fieldExists("b.d.e")) + assert(schema.fieldExists("f")) + assert(schema.fieldExists("f.g")) + assert(schema.fieldExists("f.g.h")) + assert(!schema.fieldExists("z")) + assert(!schema.fieldExists("x.y.z")) + assert(!schema.fieldExists("f.g.h.a")) + } + + test("Test isColumnArrayOfStruct") { + assert(!schema.isColumnArrayOfStruct("a")) + assert(!schema.isColumnArrayOfStruct("b")) + assert(!schema.isColumnArrayOfStruct("b.c")) + assert(!schema.isColumnArrayOfStruct("b.d")) + assert(!schema.isColumnArrayOfStruct("b.d.e")) + assert(!schema.isColumnArrayOfStruct("f")) + assert(schema.isColumnArrayOfStruct("f.g")) + assert(!schema.isColumnArrayOfStruct("f.g.h")) + 
assert(!nestedSchema.isColumnArrayOfStruct("a")) + assert(nestedSchema.isColumnArrayOfStruct("b")) + assert(nestedSchema.isColumnArrayOfStruct("b.c.d")) + } + + + test("Testing getAllArrayPaths") { + assertResult(Seq("f.g"))(schema.getAllArrayPaths()) + val newSchema = schema("b").dataType.asInstanceOf[StructType] + assertResult(Seq())(newSchema.getAllArrayPaths()) + } + + test("Testing getFieldNullability") { + assert(schema.getFieldNullability("b.d").get) + assert(schema.getFieldNullability("x.y.z").isEmpty) + } + + test("Test isOnlyField()") { + val schema = StructType(Seq[StructField]( + StructField("a", StringType), + StructField("b", StructType(Seq[StructField]( + StructField("e", StringType), + StructField("f", StringType) + ))), + StructField("c", StructType(Seq[StructField]( + StructField("d", StringType) + ))) + )) + + assert(!schema.isOnlyField("a")) + assert(!schema.isOnlyField("b.e")) + assert(!schema.isOnlyField("b.f")) + assert(schema.isOnlyField("c.d")) + } + + test("Test getStructField on array of arrays") { + assert(arrayOfArraysSchema.getField("a").contains(StructField("a", ArrayType(ArrayType(IntegerType)), nullable = false))) + assert(arrayOfArraysSchema.getField("b").contains(StructField("b", ArrayType(ArrayType(StructType(Seq(StructField("c", StringType, nullable = false))))), nullable = true))) + assert(arrayOfArraysSchema.getField("b.c").contains(StructField("c", StringType, nullable = false))) + assert(arrayOfArraysSchema.getField("b.d").isEmpty) + } + + test("Test fieldExists") { + assert(schema.fieldExists("a")) + assert(schema.fieldExists("b")) + assert(schema.fieldExists("b.c")) + assert(schema.fieldExists("b.d")) + assert(schema.fieldExists("b.d.e")) + assert(schema.fieldExists("f")) + assert(schema.fieldExists("f.g")) + assert(schema.fieldExists("f.g.h")) + assert(!schema.fieldExists("z")) + assert(!schema.fieldExists("x.y.z")) + assert(!schema.fieldExists("f.g.h.a")) + + assert(arrayOfArraysSchema.fieldExists("a")) + assert(arrayOfArraysSchema.fieldExists("b")) + assert(arrayOfArraysSchema.fieldExists("b.c")) + assert(!arrayOfArraysSchema.fieldExists("b.d")) + } + + test("Test getClosestUniqueName() is working properly") { + val schema = StructType(Seq[StructField]( + StructField("value", StringType), + StructField("value_1", StringType), + StructField("value_2", StringType) + )) + + // A column name that does not exist + val name1 = schema.getClosestUniqueName("v") + // A column that exists + val name2 = schema.getClosestUniqueName("value") + + assert(name1 == "v") + assert(name2 == "value_3") + } + + import spark.implicits._ + + test("be true for E in D but not vice versa") { + val schemaD = spark.read.json(Seq(jsonD).toDS).schema + val schemaE = spark.read.json(Seq(jsonE).toDS).schema + + assert(schemaD.isSubset(schemaE)) + assert(!schemaE.isSubset(schemaD)) + } + + test("be false for A and B in both directions") { + val schemaA = spark.read.json(Seq(jsonA).toDS).schema + val schemaB = spark.read.json(Seq(jsonB).toDS).schema + + assert(!schemaA.isSubset(schemaB)) + assert(!schemaB.isSubset(schemaA)) + } + + test("say true for the same schemas") { + val dfA1 = spark.read.json(Seq(jsonA).toDS) + val dfA2 = spark.read.json(Seq(jsonA).toDS) + + assert(dfA1.schema.isEquivalent(dfA2.schema)) + } + + test("say false when first schema has an extra field") { + val dfA = spark.read.json(Seq(jsonA).toDS) + val dfB = spark.read.json(Seq(jsonB).toDS) + + assert(!dfA.schema.isEquivalent(dfB.schema)) + } + + test("say false when second schema has an extra field") { + val
dfA = spark.read.json(Seq(jsonA).toDS) + val dfB = spark.read.json(Seq(jsonB).toDS) + + assert(!dfB.schema.isEquivalent(dfA.schema)) + } + + test("produce a list of differences with path for schemas with different columns") { + val schemaA = spark.read.json(Seq(jsonA).toDS).schema + val schemaB = spark.read.json(Seq(jsonB).toDS).schema + + assertResult(List("key cannot be found in both schemas"))(schemaA.diffSchema(schemaB)) + assertResult(List("legs.conditions.price cannot be found in both schemas"))(schemaB.diffSchema(schemaA)) + } + + test("produce a list of differences with path for schemas with different column types") { + val schemaC = spark.read.json(Seq(jsonC).toDS).schema + val schemaD = spark.read.json(Seq(jsonD).toDS).schema + + val result = List( + "key.alfa data type doesn't match (string) vs (long)", + "key.beta.beta2 data type doesn't match (string) vs (long)" + ) + + assertResult(result)(schemaC.diffSchema(schemaD)) + } + + test("produce an empty list for identical schemas") { + val schemaA = spark.read.json(Seq(jsonA).toDS).schema + val schemaB = spark.read.json(Seq(jsonA).toDS).schema + + assert(schemaA.diffSchema(schemaB).isEmpty) + assert(schemaB.diffSchema(schemaA).isEmpty) + } + +} diff --git a/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilsSpec.scala b/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilsSpec.scala deleted file mode 100644 index 219f0e6f..00000000 --- a/src/test/scala/za/co/absa/spark/commons/schema/SchemaUtilsSpec.scala +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Copyright 2021 ABSA Group Limited - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -package za.co.absa.spark.commons.schema - -import org.apache.spark.sql.AnalysisException -import org.scalatest.BeforeAndAfterAll -import org.scalatest.flatspec.AnyFlatSpec -import org.scalatest.matchers.should.Matchers -import za.co.absa.spark.commons.test.SparkTestBase - -class SchemaUtilsSpec extends AnyFlatSpec with Matchers with BeforeAndAfterAll with SparkTestBase { - - - import spark.implicits._ - - val jsonA = """[{"id":1,"legs":[{"legid":100,"conditions":[{"checks":[{"checkNums":["1","2","3b","4","5c","6"]}],"amount":100}]}], "key" : {"alfa": "1", "beta": {"beta2": "2"}} }]""" - val jsonB = """[{"id":1,"legs":[{"legid":100,"conditions":[{"checks":[{"checkNums":["1","2","3b","4","5c","6"]}],"amount":100,"price":10}]}]}]""" - val jsonC = """[{"legs":[{"legid":100,"conditions":[{"amount":100,"checks":[{"checkNums":["1","2","3b","4","5c","6"]}]}]}],"id":1, "key" : {"beta": {"beta2": "2"}, "alfa": "1"} }]""" - val jsonD = """[{"legs":[{"legid":100,"conditions":[{"amount":100,"checks":[{"checkNums":["1","2","3b","4","5c","6"]}]}]}],"id":1, "key" : {"beta": {"beta2": 2}, "alfa": 1} }]""" - val jsonE = """[{"legs":[{"legid":100,"conditions":[{"amount":100,"checks":[{"checkNums":["1","2","3b","4","5c","6"]}]}]}],"id":1, "key" : {"beta": {"beta2": 2}, "alfa": 1}, "extra" : "a"}]""" - - behavior of "isSameSchema" - - it should "say true for the same schemas" in { - val dfA1 = spark.read.json(Seq(jsonA).toDS) - val dfA2 = spark.read.json(Seq(jsonA).toDS) - - SchemaUtils.equivalentSchemas(dfA1.schema, dfA2.schema) should be(true) - } - - it should "say false when first schema has an extra field" in { - val dfA = spark.read.json(Seq(jsonA).toDS) - val dfB = spark.read.json(Seq(jsonB).toDS) - - SchemaUtils.equivalentSchemas(dfA.schema, dfB.schema) should be(false) - } - - it should "say false when second schema has an extra field" in { - val dfA = spark.read.json(Seq(jsonA).toDS) - val dfB = spark.read.json(Seq(jsonB).toDS) - - SchemaUtils.equivalentSchemas(dfB.schema, dfA.schema) should be(false) - } - - behavior of "alignSchema" - - it should "order schemas for equal schemas" in { - val dfA = spark.read.json(Seq(jsonA).toDS) - val dfC = spark.read.json(Seq(jsonC).toDS).select("legs", "id", "key") - - val dfA2Aligned = SchemaUtils.alignSchema(dfC, dfA.schema) - - dfA.columns.toSeq should equal(dfA2Aligned.columns.toSeq) - dfA.select("key").columns.toSeq should equal(dfA2Aligned.select("key").columns.toSeq) - } - - it should "throw an error for DataFrames with different schemas" in { - val dfA = spark.read.json(Seq(jsonA).toDS) - val dfB = spark.read.json(Seq(jsonB).toDS) - - intercept[AnalysisException] { - SchemaUtils.alignSchema(dfA, dfB.schema) - } - } - - behavior of "diffSchema" - - it should "produce a list of differences with path for schemas with different columns" in { - val schemaA = spark.read.json(Seq(jsonA).toDS).schema - val schemaB = spark.read.json(Seq(jsonB).toDS).schema - - SchemaUtils.diffSchema(schemaA, schemaB) should equal(List("key cannot be found in both schemas")) - SchemaUtils.diffSchema(schemaB, schemaA) should equal(List("legs.conditions.price cannot be found in both schemas")) - } - - it should "produce a list of differences with path for schemas with different column types" in { - val schemaC = spark.read.json(Seq(jsonC).toDS).schema - val schemaD = spark.read.json(Seq(jsonD).toDS).schema - - val result = List( - "key.alfa data type doesn't match (string) vs (long)", - "key.beta.beta2 data type doesn't match (string) vs (long)" - ) - - 
SchemaUtils.diffSchema(schemaC, schemaD) should equal(result) - } - - it should "produce an empty list for identical schemas" in { - val schemaA = spark.read.json(Seq(jsonA).toDS).schema - val schemaB = spark.read.json(Seq(jsonA).toDS).schema - - SchemaUtils.diffSchema(schemaA, schemaB).isEmpty should be(true) - SchemaUtils.diffSchema(schemaB, schemaA).isEmpty should be(true) - } - - behavior of "isSubset" - - it should "be true for E in D but not vice versa" in { - val schemaD = spark.read.json(Seq(jsonD).toDS).schema - val schemaE = spark.read.json(Seq(jsonE).toDS).schema - - SchemaUtils.isSubset(schemaD, schemaE) should be(true) - SchemaUtils.isSubset(schemaE, schemaD) should be(false) - } - - it should "be false for A and B in both directions" in { - val schemaA = spark.read.json(Seq(jsonA).toDS).schema - val schemaB = spark.read.json(Seq(jsonA).toDS).schema - - SchemaUtils.isSubset(schemaA, schemaB) should be(true) - SchemaUtils.isSubset(schemaB, schemaA) should be(true) - } -} diff --git a/src/test/scala/za/co/absa/spark/commons/utils/SchemaUtilsTest.scala b/src/test/scala/za/co/absa/spark/commons/utils/SchemaUtilsTest.scala new file mode 100644 index 00000000..299a712f --- /dev/null +++ b/src/test/scala/za/co/absa/spark/commons/utils/SchemaUtilsTest.scala @@ -0,0 +1,39 @@ +/* + * Copyright 2021 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.spark.commons.utils + +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers +import za.co.absa.spark.commons.utils.SchemaUtils._ + +class SchemaUtilsTest extends AnyFunSuite with Matchers { + // scalastyle:off magic.number + + test("Test isCommonSubPath()") { + assert(isCommonSubPath()) + assert(isCommonSubPath("a")) + assert(isCommonSubPath("a.b.c.d.e.f", "a.b.c.d", "a.b.c", "a.b", "a")) + assert(!isCommonSubPath("a.b.c.d.e.f", "a.b.c.x", "a.b.c", "a.b", "a")) + } + + test("Test getParentPath") { + assertResult("a.b.c.d")(getParentPath("a.b.c.d.e")) + assertResult("")(getParentPath("a")) + assertResult("a")(getParentPath("a.bcd")) + assertResult("")(getParentPath("")) + } +}
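
For reviewers who want to try the helpers this patch tests without reading the suites end to end, here is a minimal, self-contained sketch of the intended call patterns. The import paths come from the new test files above; the object name and the sample schema are illustrative only, and the expected results are inferred from the assertions in the new suites.

```scala
import org.apache.spark.sql.types._
import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements
import za.co.absa.spark.commons.utils.SchemaUtils

// Hypothetical demo object; not part of this patch.
object SchemaHelpersSketch extends App {
  // A root struct "a" holding a nested string field "b".
  val schema = StructType(Seq(
    StructField("a", StructType(Seq(
      StructField("b", StringType))))))

  // StructTypeEnhancements resolve dotted paths against a StructType.
  assert(schema.fieldExists("a.b"))                        // nested lookup by path
  assert(schema.getFieldType("a.b").contains(StringType))  // returns Option[DataType]

  // SchemaUtils works on the path strings themselves.
  assert(SchemaUtils.getParentPath("a.b") == "a")  // parent of a nested path
  assert(SchemaUtils.getParentPath("a").isEmpty)   // root fields have no parent
  assert(SchemaUtils.isCommonSubPath("a.b", "a"))  // each path prefixes the longer one
}
```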