@@ -566,6 +566,10 @@ public int hashCode() {
return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42);
}

public int hashCode(int seed) {
return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, seed);
}

@Override
public boolean equals(Object other) {
if (other instanceof UnsafeRow) {
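As a hypothetical sanity check (not part of this PR), the existing no-arg hashCode should now be equivalent to calling the new seeded overload with the fixed seed 42:

// Hypothetical check, not PR code: the no-arg hashCode should equal
// the new seeded overload when it is given the default seed 42.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, UnsafeProjection}

val row = UnsafeProjection.create(Seq(Literal(1), Literal("a"))).apply(InternalRow.empty)
assert(row.hashCode == row.hashCode(42))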
@@ -49,7 +49,7 @@ trait FunctionRegistry {

class SimpleFunctionRegistry extends FunctionRegistry {

private val functionBuilders =
private[sql] val functionBuilders =
StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false)

override def registerFunction(
@@ -278,6 +278,7 @@ object FunctionRegistry {
// misc functions
expression[Crc32]("crc32"),
expression[Md5]("md5"),
expression[Murmur3Hash]("hash"),
expression[Sha1]("sha"),
expression[Sha1]("sha1"),
expression[Sha2]("sha2"),
@@ -22,6 +22,8 @@ import java.util.zip.CRC32

import org.apache.commons.codec.digest.DigestUtils

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -177,3 +179,45 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp
})
}
}

/**
* A function that calculates a hash value for a group of expressions. Note that the `seed`
* argument is not exposed to users and should only be set inside Spark SQL.
*
* Internally this function writes the arguments into an [[UnsafeRow]], and calculates the hash
* code of that unsafe row using a Murmur3 hasher with the given seed.
* We should use this hash function for both shuffle and bucketing, so that we can guarantee that
* shuffle and bucketing have the same data distribution.
*/
case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression {
def this(arguments: Seq[Expression]) = this(arguments, 42)
Contributor (author):

use 42 as the default seed, which is the same as UnsafeRow.hashCode. Should we make 42 a constant in Murmur3_x86_32?

Contributor:

I think this is fine.

Can you file a follow-up JIRA to look at this again? I think we want to remove the projection to unsafe row soon (before we ship this and persist metadata that way). Ideally this should be decoupled from unsafe row. For example, if the row is (int, double, string), the generated hash function should be something like:

int hash = seed;
hash = murmur3(getInt(0), hash);
hash = murmur3(getDouble(1), hash);
hash = murmur3(getString(2), hash);
return hash;

This is likely not the currently computed hash value, so we can't defer this for too long. (A sketch of this field-by-field approach appears after the class below.)


override def dataType: DataType = IntegerType

override def foldable: Boolean = children.forall(_.foldable)

override def nullable: Boolean = false

override def checkInputDataTypes(): TypeCheckResult = {
if (children.isEmpty) {
TypeCheckResult.TypeCheckFailure("arguments of function hash cannot be empty")
} else {
TypeCheckResult.TypeCheckSuccess
}
}

private lazy val unsafeProjection = UnsafeProjection.create(children)

override def eval(input: InternalRow): Any = {
unsafeProjection(input).hashCode(seed)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val unsafeRow = GenerateUnsafeProjection.createCode(ctx, children)
ev.isNull = "false"
s"""
${unsafeRow.code}
final int ${ev.value} = ${unsafeRow.value}.hashCode($seed);
"""
}
}
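A minimal sketch of the field-by-field hashing described in the review comment above, using Spark's Murmur3_x86_32 helpers. It is illustrative only and not part of this PR; the hashUnsafeBytes call and the double-to-long-bits handling are assumptions.

// Illustrative sketch (not PR code): hash an (int, double, string) row field by
// field, threading each result in as the seed for the next field, instead of
// projecting the row to an UnsafeRow first.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.unsafe.hash.Murmur3_x86_32
import org.apache.spark.unsafe.types.UTF8String

def fieldByFieldHash(row: InternalRow, seed: Int): Int = {
  var hash = seed
  hash = Murmur3_x86_32.hashInt(row.getInt(0), hash)
  // hash the double via its long bit pattern (an assumption)
  hash = Murmur3_x86_32.hashLong(java.lang.Double.doubleToLongBits(row.getDouble(1)), hash)
  // hash the string over its UTF-8 bytes (assumes a hashUnsafeBytes helper)
  val str: UTF8String = row.getUTF8String(2)
  hash = Murmur3_x86_32.hashUnsafeBytes(str.getBaseObject, str.getBaseOffset, str.numBytes, hash)
  hash
}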
@@ -99,7 +99,7 @@ class RowEncoderSuite extends SparkFunSuite {
.add("binary", BinaryType)
.add("date", DateType)
.add("timestamp", TimestampType)
.add("udt", new ExamplePointUDT, false))
.add("udt", new ExamplePointUDT))

encodeDecodeTest(
new StructType()
@@ -20,7 +20,9 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.commons.codec.digest.DigestUtils

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType}
import org.apache.spark.sql.{Row, RandomDataGenerator}
import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
import org.apache.spark.sql.types._

class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

@@ -59,4 +61,73 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Crc32(Literal.create(null, BinaryType)), null)
checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
}

private val structOfString = new StructType().add("str", StringType)
private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
private val arrayOfString = ArrayType(StringType)
private val arrayOfNull = ArrayType(NullType)
private val mapOfString = MapType(StringType, StringType)
private val arrayOfUDT = ArrayType(new ExamplePointUDT, false)

testMurmur3Hash(
new StructType()
.add("null", NullType)
.add("boolean", BooleanType)
.add("byte", ByteType)
.add("short", ShortType)
.add("int", IntegerType)
.add("long", LongType)
.add("float", FloatType)
.add("double", DoubleType)
.add("decimal", DecimalType.SYSTEM_DEFAULT)
.add("string", StringType)
.add("binary", BinaryType)
.add("date", DateType)
.add("timestamp", TimestampType)
.add("udt", new ExamplePointUDT))

testMurmur3Hash(
new StructType()
.add("arrayOfNull", arrayOfNull)
.add("arrayOfString", arrayOfString)
.add("arrayOfArrayOfString", ArrayType(arrayOfString))
.add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType)))
.add("arrayOfMap", ArrayType(mapOfString))
.add("arrayOfStruct", ArrayType(structOfString))
.add("arrayOfUDT", arrayOfUDT))

testMurmur3Hash(
new StructType()
.add("mapOfIntAndString", MapType(IntegerType, StringType))
.add("mapOfStringAndArray", MapType(StringType, arrayOfString))
.add("mapOfArrayAndInt", MapType(arrayOfString, IntegerType))
.add("mapOfArray", MapType(arrayOfString, arrayOfString))
.add("mapOfStringAndStruct", MapType(StringType, structOfString))
.add("mapOfStructAndString", MapType(structOfString, StringType))
.add("mapOfStruct", MapType(structOfString, structOfString)))

testMurmur3Hash(
new StructType()
.add("structOfString", structOfString)
.add("structOfStructOfString", new StructType().add("struct", structOfString))
.add("structOfArray", new StructType().add("array", arrayOfString))
.add("structOfMap", new StructType().add("map", mapOfString))
.add("structOfArrayAndMap",
new StructType().add("array", arrayOfString).add("map", mapOfString))
.add("structOfUDT", structOfUDT))

private def testMurmur3Hash(inputSchema: StructType): Unit = {
val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get
val encoder = RowEncoder(inputSchema)
val seed = scala.util.Random.nextInt()
test(s"murmur3 hash: ${inputSchema.simpleString}") {
for (_ <- 1 to 10) {
val input = encoder.toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow]
val literals = input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map {
case (value, dt) => Literal.create(value, dt)
}
checkEvaluation(Murmur3Hash(literals, seed), input.hashCode(seed))
}
}
}
}
11 changes: 11 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1863,6 +1863,17 @@ object functions extends LegacyFunctions {
*/
def crc32(e: Column): Column = withExpr { Crc32(e.expr) }

/**
* Calculates the hash code of the given columns, and returns the result as an int column.
*
* @group misc_funcs
* @since 2.0
*/
@scala.annotation.varargs
def hash(col: Column, cols: Column*): Column = withExpr {
Contributor:

should this just be a single column vararg, rather than one required column followed by a vararg?

Contributor (author):

the hash function should take at least one parameter; does @scala.annotation.varargs support this?

Contributor:

You can use the following form: (firstarg: Int)(more: Int*). (A sketch of both forms appears after this function.)

new Murmur3Hash((col +: cols).map(_.expr))
}
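A toy sketch (hypothetical names, not PR code) of the two signatures discussed in the review thread above. Both shapes force callers to pass at least one argument; the first is the one @scala.annotation.varargs can expose to Java as a single vararg.

// Toy illustration with hypothetical names, not PR code.
object VarargsDemo {
  // The shape used by `hash`: one required argument followed by varargs.
  // @scala.annotation.varargs also generates a Java-friendly sum(int, int...).
  @scala.annotation.varargs
  def sumAtLeastOne(first: Int, rest: Int*): Int = (first +: rest).sum

  // The curried alternative the reviewer suggests; it also requires at least
  // one argument, but compiles to two parameter lists.
  def sumAtLeastOneCurried(first: Int)(rest: Int*): Int = (first +: rest).sum

  def main(args: Array[String]): Unit = {
    println(sumAtLeastOne(1, 2, 3))        // prints 6
    println(sumAtLeastOneCurried(1)(2, 3)) // prints 6
    // sumAtLeastOne() does not compile: the first argument is required.
  }
}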

//////////////////////////////////////////////////////////////////////////////////////////////
// String functions
//////////////////////////////////////////////////////////////////////////////////////////////
10 changes: 10 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2057,4 +2057,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
)
}

test("hash function") {
val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
withTempTable("tbl") {
df.registerTempTable("tbl")
checkAnswer(
df.select(hash($"i", $"j")),
sql("SELECT hash(i, j) from tbl")
)
}
}
}
@@ -53,6 +53,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5)
// Enable in-memory partition pruning for testing purposes
TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true)
// Use Hive hash expression instead of the native one
TestHive.functionRegistry.unregisterFunction("hash")
RuleExecutor.resetTime()
}

@@ -62,6 +64,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
Locale.setDefault(originalLocale)
TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
TestHive.functionRegistry.restore()

// For debugging dump some statistics about how much time was spent in various optimizer rules.
logWarning(RuleExecutor.dumpTimeSpent())
@@ -31,10 +31,13 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe

import org.apache.spark.sql.{SQLContext, SQLConf}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.CacheTableCommand
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.execution.HiveNativeCommand
import org.apache.spark.sql.hive.client.ClientWrapper
import org.apache.spark.util.{ShutdownHookManager, Utils}
import org.apache.spark.{SparkConf, SparkContext}

@@ -451,6 +454,27 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
logError("FATAL ERROR: Failed to reset TestDB state.", e)
}
}

@transient
override protected[sql] lazy val functionRegistry = new TestHiveFunctionRegistry(
org.apache.spark.sql.catalyst.analysis.FunctionRegistry.builtin.copy(), this.executionHive)
}

private[hive] class TestHiveFunctionRegistry(fr: SimpleFunctionRegistry, client: ClientWrapper)
extends HiveFunctionRegistry(fr, client) {

private val removedFunctions =
collection.mutable.ArrayBuffer.empty[(String, (ExpressionInfo, FunctionBuilder))]

def unregisterFunction(name: String): Unit = {
fr.functionBuilders.remove(name).foreach(f => removedFunctions += name -> f)
}

def restore(): Unit = {
removedFunctions.foreach {
case (name, (info, builder)) => fr.registerFunction(name, info, builder)
}
}
}

private[hive] object TestHiveContext {
@@ -387,9 +387,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
createQueryTest("partitioned table scan",
"SELECT ds, hr, key, value FROM srcpart")

createQueryTest("hash",
"SELECT hash('test') FROM src LIMIT 1")

createQueryTest("create table as",
"""
|CREATE TABLE createdtable AS SELECT * FROM src;