core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -17,7 +17,7 @@

 package org.apache.spark.serializer

-import java.io.{DataInput, DataOutput, EOFException, InputStream, IOException, OutputStream}
+import java.io._
 import java.nio.ByteBuffer
 import javax.annotation.Nullable

@@ -378,18 +378,24 @@ private[serializer] object KryoSerializer {
   private val toRegisterSerializer = Map[Class[_], KryoClassSerializer[_]](
     classOf[RoaringBitmap] -> new KryoClassSerializer[RoaringBitmap]() {
       override def write(kryo: Kryo, output: KryoOutput, bitmap: RoaringBitmap): Unit = {
-        bitmap.serialize(new KryoOutputDataOutputBridge(output))
+        bitmap.serialize(new KryoOutputObjectOutputBridge(kryo, output))
       }
       override def read(kryo: Kryo, input: KryoInput, cls: Class[RoaringBitmap]): RoaringBitmap = {
         val ret = new RoaringBitmap
-        ret.deserialize(new KryoInputDataInputBridge(input))
+        ret.deserialize(new KryoInputObjectInputBridge(kryo, input))
         ret
       }
     }
   )
 }

-private[serializer] class KryoInputDataInputBridge(input: KryoInput) extends DataInput {
+/**
+ * This is a bridge class to wrap KryoInput as an InputStream and ObjectInput. It forwards all
+ * methods of InputStream and ObjectInput to KryoInput. It's usually helpful when an API expects
+ * an InputStream or ObjectInput but you want to use Kryo.
+ */
+private[spark] class KryoInputObjectInputBridge(
+    kryo: Kryo, input: KryoInput) extends FilterInputStream(input) with ObjectInput {
   override def readLong(): Long = input.readLong()
   override def readChar(): Char = input.readChar()
   override def readFloat(): Float = input.readFloat()
@@ -408,9 +414,16 @@ private[serializer] class KryoInputDataInputBridge(input: KryoInput) extends DataInput {
   override def readBoolean(): Boolean = input.readBoolean()
   override def readUnsignedByte(): Int = input.readByteUnsigned()
   override def readDouble(): Double = input.readDouble()
+  override def readObject(): AnyRef = kryo.readClassAndObject(input)
 }

-private[serializer] class KryoOutputDataOutputBridge(output: KryoOutput) extends DataOutput {
+/**
+ * This is a bridge class to wrap KryoOutput as an OutputStream and ObjectOutput. It forwards all
+ * methods of OutputStream and ObjectOutput to KryoOutput. It's usually helpful when an API expects
+ * an OutputStream or ObjectOutput but you want to use Kryo.
+ */
+private[spark] class KryoOutputObjectOutputBridge(

> Contributor: Can you put some docs on this class to explain what this does? Same for the above class.

+    kryo: Kryo, output: KryoOutput) extends FilterOutputStream(output) with ObjectOutput {
   override def writeFloat(v: Float): Unit = output.writeFloat(v)
   // There is no "readChars" counterpart, except maybe "readLine", which is not supported
   override def writeChars(s: String): Unit = throw new UnsupportedOperationException("writeChars")

@@ -426,6 +439,7 @@ private[serializer] class KryoOutputDataOutputBridge(output: KryoOutput) extends DataOutput {
   override def writeChar(v: Int): Unit = output.writeChar(v.toChar)
   override def writeLong(v: Long): Unit = output.writeLong(v)
   override def writeByte(v: Int): Unit = output.writeByte(v)
+  override def writeObject(obj: AnyRef): Unit = kryo.writeClassAndObject(output, obj)

> Contributor: Should there be a new unit test in the KryoSerializerSuite to test this?
> Member (author): done

 }

 /**
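As an end-to-end illustration (not part of this diff), the following minimal sketch round-trips a RoaringBitmap through Spark's KryoSerializer; the serializer registered above routes the bytes through the two new bridge classes:

    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.KryoSerializer
    import org.roaringbitmap.RoaringBitmap

    // Serialize and deserialize via Spark's Kryo path; internally this exercises
    // the KryoOutputObjectOutputBridge/KryoInputObjectInputBridge added above.
    val ser = new KryoSerializer(new SparkConf()).newInstance()
    val bitmap = new RoaringBitmap()
    bitmap.add(1)
    bitmap.add(3)
    bitmap.add(5)
    val restored = ser.deserialize[RoaringBitmap](ser.serialize(bitmap))
    assert(restored == bitmap)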
core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -362,19 +362,35 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext {
     bitmap.add(1)
     bitmap.add(3)
     bitmap.add(5)
-    bitmap.serialize(new KryoOutputDataOutputBridge(output))
+    // The Kryo instance may be null here: RoaringBitmap.serialize never calls writeObject
+    bitmap.serialize(new KryoOutputObjectOutputBridge(null, output))
     output.flush()
     output.close()

     val inStream = new FileInputStream(tmpfile)
     val input = new KryoInput(inStream)
     val ret = new RoaringBitmap
-    ret.deserialize(new KryoInputDataInputBridge(input))
+    // Likewise, RoaringBitmap.deserialize never calls readObject
+    ret.deserialize(new KryoInputObjectInputBridge(null, input))
     input.close()
     assert(ret == bitmap)
     Utils.deleteRecursively(dir)
   }

+  test("KryoOutputObjectOutputBridge.writeObject and KryoInputObjectInputBridge.readObject") {
+    val kryo = new KryoSerializer(conf).newKryo()
+
+    val bytesOutput = new ByteArrayOutputStream()
+    val objectOutput = new KryoOutputObjectOutputBridge(kryo, new KryoOutput(bytesOutput))
+    objectOutput.writeObject("test")
+    objectOutput.close()
+
+    val bytesInput = new ByteArrayInputStream(bytesOutput.toByteArray)
+    val objectInput = new KryoInputObjectInputBridge(kryo, new KryoInput(bytesInput))
+    assert(objectInput.readObject() === "test")
+    objectInput.close()
+  }
+
   test("getAutoReset") {
     val ser = new KryoSerializer(new SparkConf).newInstance().asInstanceOf[KryoSerializerInstance]
     assert(ser.getAutoReset)
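Beyond the direct writeObject/readObject test above, the use case the new class docs describe (handing a Kryo stream to an API that only knows java.io) can be sketched as follows. This is illustrative only: Point is a hypothetical Externalizable class invented for the example, and because the bridges are private[spark], the code is assumed to live in an org.apache.spark package.

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, Externalizable, ObjectInput, ObjectOutput}

    import com.esotericsoftware.kryo.Kryo
    import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}

    import org.apache.spark.serializer.{KryoInputObjectInputBridge, KryoOutputObjectOutputBridge}

    // A hypothetical java.io-only value class: it knows how to write itself
    // to any ObjectOutput and read itself back from any ObjectInput.
    class Point(var x: Int, var y: Int) extends Externalizable {
      def this() = this(0, 0) // Externalizable requires a no-arg constructor
      override def writeExternal(out: ObjectOutput): Unit = { out.writeInt(x); out.writeInt(y) }
      override def readExternal(in: ObjectInput): Unit = { x = in.readInt(); y = in.readInt() }
    }

    val kryo = new Kryo()
    val bytes = new ByteArrayOutputStream()
    val output = new KryoOutput(bytes)
    // The bridge lets the java.io API write straight into a Kryo stream.
    new Point(1, 2).writeExternal(new KryoOutputObjectOutputBridge(kryo, output))
    output.close()

    val input = new KryoInput(new ByteArrayInputStream(bytes.toByteArray))
    val p = new Point()
    p.readExternal(new KryoInputObjectInputBridge(kryo, input))
    input.close()
    assert(p.x == 1 && p.y == 2)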
project/MimaExcludes.scala (4 additions, 0 deletions)
@@ -119,6 +119,10 @@ object MimaExcludes {
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector"),
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$Multiplier"),
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$")
+    ) ++ Seq(
+      // SPARK-12591 Register OpenHashMapBasedStateMap for Kryo
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoInputDataInputBridge"),
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoOutputDataOutputBridge")
     ) ++ Seq(
       // SPARK-12510 Refactor ActorReceiver to support Java
       ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver")
streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
@@ -17,16 +17,20 @@

 package org.apache.spark.streaming.util

-import java.io.{ObjectInputStream, ObjectOutputStream}
+import java.io._

 import scala.reflect.ClassTag

+import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
+import com.esotericsoftware.kryo.io.{Input, Output}
+
 import org.apache.spark.SparkConf
+import org.apache.spark.serializer.{KryoOutputObjectOutputBridge, KryoInputObjectInputBridge}
 import org.apache.spark.streaming.util.OpenHashMapBasedStateMap._

> Member (author): I removed ClassTag because EmptyStateMap doesn't need it. If we remove them, we don't need to add any code for EmptyStateMap, since it doesn't contain any fields.
> Contributor: Good idea.

 import org.apache.spark.util.collection.OpenHashMap

 /** Internal interface for defining the map that keeps track of sessions. */
-private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Serializable {
+private[streaming] abstract class StateMap[K, S] extends Serializable {

   /** Get the state for a key if it exists */
   def get(key: K): Option[S]
@@ -54,7 +58,7 @@ private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Serializable {

 /** Companion object for [[StateMap]], with utility methods */
 private[streaming] object StateMap {
-  def empty[K: ClassTag, S: ClassTag]: StateMap[K, S] = new EmptyStateMap[K, S]
+  def empty[K, S]: StateMap[K, S] = new EmptyStateMap[K, S]

   def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = {
     val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold",
@@ -64,7 +68,7 @@ private[streaming] object StateMap {
 }

 /** Implementation of StateMap interface representing an empty map */
-private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] {
+private[streaming] class EmptyStateMap[K, S] extends StateMap[K, S] {
   override def put(key: K, session: S, updateTime: Long): Unit = {
     throw new NotImplementedError("put() should not be called on an EmptyStateMap")
   }
@@ -77,21 +81,26 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] {
 }

 /** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */
-private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
+private[streaming] class OpenHashMapBasedStateMap[K, S](
     @transient @volatile var parentStateMap: StateMap[K, S],
-    initialCapacity: Int = DEFAULT_INITIAL_CAPACITY,
-    deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD
-  ) extends StateMap[K, S] { self =>
+    private var initialCapacity: Int = DEFAULT_INITIAL_CAPACITY,
+    private var deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD
+  )(implicit private var keyClassTag: ClassTag[K], private var stateClassTag: ClassTag[S])

> Member (author): Add keyClassTag and stateClassTag so that we can recover them.

+  extends StateMap[K, S] with KryoSerializable { self =>

-  def this(initialCapacity: Int, deltaChainThreshold: Int) = this(
+  def this(initialCapacity: Int, deltaChainThreshold: Int)
+      (implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = this(
     new EmptyStateMap[K, S],
     initialCapacity = initialCapacity,
     deltaChainThreshold = deltaChainThreshold)

-  def this(deltaChainThreshold: Int) = this(
+  def this(deltaChainThreshold: Int)
+      (implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = this(
     initialCapacity = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold = deltaChainThreshold)

-  def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD)
+  def this()(implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = {
+    this(DELTA_CHAIN_LENGTH_THRESHOLD)
+  }

   require(initialCapacity >= 1, "Invalid initial capacity")
   require(deltaChainThreshold >= 1, "Invalid delta chain threshold")
@@ -206,11 +215,7 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
    * Serialize the map data. Besides serialization, this method actually compact the deltas
    * (if needed) in a single pass over all the data in the map.
    */
-
-  private def writeObject(outputStream: ObjectOutputStream): Unit = {
-    // Write all the non-transient fields, especially class tags, etc.
-    outputStream.defaultWriteObject()
-
+  private def writeObjectInternal(outputStream: ObjectOutput): Unit = {
     // Write the data in the delta of this state map
     outputStream.writeInt(deltaMap.size)
     val deltaMapIterator = deltaMap.iterator
@@ -262,11 +267,7 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
   }

   /** Deserialize the map data. */
-  private def readObject(inputStream: ObjectInputStream): Unit = {
-
-    // Read the non-transient fields, especially class tags, etc.
-    inputStream.defaultReadObject()
-
+  private def readObjectInternal(inputStream: ObjectInput): Unit = {
     // Read the data of the delta
     val deltaMapSize = inputStream.readInt()
     deltaMap = if (deltaMapSize != 0) {
@@ -309,6 +310,34 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
     }
     parentStateMap = newParentSessionStore
   }
+
+  private def writeObject(outputStream: ObjectOutputStream): Unit = {
+    // Write all the non-transient fields, especially class tags, etc.
+    outputStream.defaultWriteObject()
+    writeObjectInternal(outputStream)
+  }
+
+  private def readObject(inputStream: ObjectInputStream): Unit = {
+    // Read the non-transient fields, especially class tags, etc.
+    inputStream.defaultReadObject()
+    readObjectInternal(inputStream)
+  }
+
+  override def write(kryo: Kryo, output: Output): Unit = {
+    output.writeInt(initialCapacity)
+    output.writeInt(deltaChainThreshold)
+    kryo.writeClassAndObject(output, keyClassTag)
+    kryo.writeClassAndObject(output, stateClassTag)
+    writeObjectInternal(new KryoOutputObjectOutputBridge(kryo, output))
+  }
+
+  override def read(kryo: Kryo, input: Input): Unit = {
+    initialCapacity = input.readInt()
+    deltaChainThreshold = input.readInt()
+    keyClassTag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[K]]
+    stateClassTag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[S]]
+    readObjectInternal(new KryoInputObjectInputBridge(kryo, input))
+  }
 }

 /**
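With KryoSerializable implemented, the state map now survives a Kryo round trip without ever touching Java serialization. A minimal sketch of the effect (the suite below automates this; the code is assumed to sit in an org.apache.spark.streaming package because the class is private[streaming]):

    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.KryoSerializer
    import org.apache.spark.streaming.util.OpenHashMapBasedStateMap

    // Kryo calls the write()/read() methods above instead of Java serialization,
    // so the class tags and the delta chain are restored explicitly from the stream.
    val map = new OpenHashMapBasedStateMap[Int, Int]()
    map.put(1, 100, updateTime = 1L)
    val ser = new KryoSerializer(new SparkConf()).newInstance()
    val restored = ser.deserialize[OpenHashMapBasedStateMap[Int, Int]](ser.serialize(map))
    assert(restored.get(1) == Some(100))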
streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
@@ -17,15 +17,23 @@

 package org.apache.spark.streaming

+import org.apache.spark.streaming.rdd.MapWithStateRDDRecord
+
 import scala.collection.{immutable, mutable, Map}
 import scala.reflect.ClassTag
 import scala.util.Random

-import org.apache.spark.SparkFunSuite
+import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
+import com.esotericsoftware.kryo.io.{Output, Input}
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.serializer._
 import org.apache.spark.streaming.util.{EmptyStateMap, OpenHashMapBasedStateMap, StateMap}
 import org.apache.spark.util.Utils

 class StateMapSuite extends SparkFunSuite {

+  private val conf = new SparkConf()
+
   test("EmptyStateMap") {
     val map = new EmptyStateMap[Int, Int]
     intercept[scala.NotImplementedError] {
@@ -128,17 +136,17 @@ class StateMapSuite extends SparkFunSuite {
     map1.put(2, 200, 2)
     testSerialization(map1, "error deserializing and serialized map with data + no delta")

-    val map2 = map1.copy()
+    val map2 = map1.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]]
     // Do not test compaction
-    assert(map2.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false)
+    assert(map2.shouldCompact === false)
     testSerialization(map2, "error deserializing and serialized map with 1 delta + no new data")

     map2.put(3, 300, 3)
     map2.put(4, 400, 4)
     testSerialization(map2, "error deserializing and serialized map with 1 delta + new data")

-    val map3 = map2.copy()
-    assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false)
+    val map3 = map2.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]]
+    assert(map3.shouldCompact === false)
     testSerialization(map3, "error deserializing and serialized map with 2 delta + no new data")
     map3.put(3, 600, 3)
     map3.remove(2)
@@ -267,18 +275,25 @@ class StateMapSuite extends SparkFunSuite {
     assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map")
   }

-  private def testSerialization[MapType <: StateMap[Int, Int]](
-      map: MapType, msg: String): MapType = {
-    val deserMap = Utils.deserialize[MapType](
-      Utils.serialize(map), Thread.currentThread().getContextClassLoader)
+  private def testSerialization[T: ClassTag](
+      map: OpenHashMapBasedStateMap[T, T], msg: String): OpenHashMapBasedStateMap[T, T] = {
+    testSerialization(new JavaSerializer(conf), map, msg)
+    testSerialization(new KryoSerializer(conf), map, msg)
+  }
+
+  private def testSerialization[T: ClassTag](
+      serializer: Serializer,
+      map: OpenHashMapBasedStateMap[T, T],
+      msg: String): OpenHashMapBasedStateMap[T, T] = {
+    val deserMap = serializeAndDeserialize(serializer, map)
     assertMap(deserMap, map, 1, msg)
     deserMap
   }

   // Assert whether all the data and operations on a state map matches that of a reference state map
-  private def assertMap(
-      mapToTest: StateMap[Int, Int],
-      refMapToTestWith: StateMap[Int, Int],
+  private def assertMap[T](
+      mapToTest: StateMap[T, T],
+      refMapToTestWith: StateMap[T, T],
       time: Long,
       msg: String): Unit = {
     withClue(msg) {
@@ -321,4 +336,59 @@ class StateMapSuite extends SparkFunSuite {
       }
     }
   }
+
+  test("OpenHashMapBasedStateMap - serializing and deserializing with KryoSerializable states") {
+    val map = new OpenHashMapBasedStateMap[KryoState, KryoState]()
+    map.put(new KryoState("a"), new KryoState("b"), 1)
+    testSerialization(
+      new KryoSerializer(conf), map, "error deserializing and serialized KryoSerializable states")
+  }
+
+  test("EmptyStateMap - serializing and deserializing") {
+    val map = StateMap.empty[KryoState, KryoState]
+    // Since EmptyStateMap doesn't contain any data, KryoState won't break JavaSerializer.
+    assert(serializeAndDeserialize(new JavaSerializer(conf), map).
+      isInstanceOf[EmptyStateMap[KryoState, KryoState]])
+    assert(serializeAndDeserialize(new KryoSerializer(conf), map).
+      isInstanceOf[EmptyStateMap[KryoState, KryoState]])
+  }
+
+  test("MapWithStateRDDRecord - serializing and deserializing with KryoSerializable states") {
+    val map = new OpenHashMapBasedStateMap[KryoState, KryoState]()
+    map.put(new KryoState("a"), new KryoState("b"), 1)
+
+    val record =
+      MapWithStateRDDRecord[KryoState, KryoState, KryoState](map, Seq(new KryoState("c")))
+    val deserRecord = serializeAndDeserialize(new KryoSerializer(conf), record)
+    assert(!(record eq deserRecord))
+    assert(record.stateMap.getAll().toSeq === deserRecord.stateMap.getAll().toSeq)
+    assert(record.mappedData === deserRecord.mappedData)
+  }
+
+  private def serializeAndDeserialize[T: ClassTag](serializer: Serializer, t: T): T = {

> Contributor: Does this need the ClassTag?
> Member (author): Yes, deserialize needs the ClassTag.

+    val serializerInstance = serializer.newInstance()
+    serializerInstance.deserialize[T](
+      serializerInstance.serialize(t), Thread.currentThread().getContextClassLoader)
+  }
 }
+
+/** A class that only supports Kryo serialization. */
+private[streaming] final class KryoState(var state: String) extends KryoSerializable {
+
+  override def write(kryo: Kryo, output: Output): Unit = {
+    kryo.writeClassAndObject(output, state)
+  }
+
+  override def read(kryo: Kryo, input: Input): Unit = {
+    state = kryo.readClassAndObject(input).asInstanceOf[String]
+  }
+
+  override def equals(other: Any): Boolean = other match {
+    case that: KryoState => state == that.state
+    case _ => false
+  }
+
+  override def hashCode(): Int = {
+    if (state == null) 0 else state.hashCode()
+  }
+}
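A closing observation, stated as an assumption rather than something this diff tests: KryoState implements only KryoSerializable, not java.io.Serializable, so plain Java serialization of it should fail outright, which is exactly what makes it a good probe for the Kryo-only path. A hypothetical extra test body inside the suite:

    import java.io.NotSerializableException

    import org.apache.spark.util.Utils

    // Utils.serialize uses ObjectOutputStream, which rejects any class that
    // does not implement java.io.Serializable.
    intercept[NotSerializableException] {
      Utils.serialize(new KryoState("a"))
    }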