-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-12591][Streaming]Register OpenHashMapBasedStateMap for Kryo #10609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bf5632e
39d3008
a65ab45
bf0892c
33368be
4e4e9a1
ee452fe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,7 @@ | |
|
|
||
| package org.apache.spark.serializer | ||
|
|
||
| import java.io.{DataInput, DataOutput, EOFException, InputStream, IOException, OutputStream} | ||
| import java.io._ | ||
| import java.nio.ByteBuffer | ||
| import javax.annotation.Nullable | ||
|
|
||
|
|
@@ -378,18 +378,24 @@ private[serializer] object KryoSerializer { | |
| private val toRegisterSerializer = Map[Class[_], KryoClassSerializer[_]]( | ||
| classOf[RoaringBitmap] -> new KryoClassSerializer[RoaringBitmap]() { | ||
| override def write(kryo: Kryo, output: KryoOutput, bitmap: RoaringBitmap): Unit = { | ||
| bitmap.serialize(new KryoOutputDataOutputBridge(output)) | ||
| bitmap.serialize(new KryoOutputObjectOutputBridge(kryo, output)) | ||
| } | ||
| override def read(kryo: Kryo, input: KryoInput, cls: Class[RoaringBitmap]): RoaringBitmap = { | ||
| val ret = new RoaringBitmap | ||
| ret.deserialize(new KryoInputDataInputBridge(input)) | ||
| ret.deserialize(new KryoInputObjectInputBridge(kryo, input)) | ||
| ret | ||
| } | ||
| } | ||
| ) | ||
| } | ||
|
|
||
| private[serializer] class KryoInputDataInputBridge(input: KryoInput) extends DataInput { | ||
| /** | ||
| * This is a bridge class to wrap KryoInput as an InputStream and ObjectInput. It forwards all | ||
| * methods of InputStream and ObjectInput to KryoInput. It's usually helpful when an API expects | ||
| * an InputStream or ObjectInput but you want to use Kryo. | ||
| */ | ||
| private[spark] class KryoInputObjectInputBridge( | ||
| kryo: Kryo, input: KryoInput) extends FilterInputStream(input) with ObjectInput { | ||
| override def readLong(): Long = input.readLong() | ||
| override def readChar(): Char = input.readChar() | ||
| override def readFloat(): Float = input.readFloat() | ||
|
|
@@ -408,9 +414,16 @@ private[serializer] class KryoInputDataInputBridge(input: KryoInput) extends Dat | |
| override def readBoolean(): Boolean = input.readBoolean() | ||
| override def readUnsignedByte(): Int = input.readByteUnsigned() | ||
| override def readDouble(): Double = input.readDouble() | ||
| override def readObject(): AnyRef = kryo.readClassAndObject(input) | ||
| } | ||
|
|
||
| private[serializer] class KryoOutputDataOutputBridge(output: KryoOutput) extends DataOutput { | ||
| /** | ||
| * This is a bridge class to wrap KryoOutput as an OutputStream and ObjectOutput. It forwards all | ||
| * methods of OutputStream and ObjectOutput to KryoOutput. It's usually helpful when an API expects | ||
| * an OutputStream or ObjectOutput but you want to use Kryo. | ||
| */ | ||
| private[spark] class KryoOutputObjectOutputBridge( | ||
| kryo: Kryo, output: KryoOutput) extends FilterOutputStream(output) with ObjectOutput { | ||
| override def writeFloat(v: Float): Unit = output.writeFloat(v) | ||
| // There is no "readChars" counterpart, except maybe "readLine", which is not supported | ||
| override def writeChars(s: String): Unit = throw new UnsupportedOperationException("writeChars") | ||
|
|
@@ -426,6 +439,7 @@ private[serializer] class KryoOutputDataOutputBridge(output: KryoOutput) extends | |
| override def writeChar(v: Int): Unit = output.writeChar(v.toChar) | ||
| override def writeLong(v: Long): Unit = output.writeLong(v) | ||
| override def writeByte(v: Int): Unit = output.writeByte(v) | ||
| override def writeObject(obj: AnyRef): Unit = kryo.writeClassAndObject(output, obj) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should there be a new unit test in the KryoSerializerSuite to test this?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| } | ||
|
|
||
| /** | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,16 +17,20 @@ | |
|
|
||
| package org.apache.spark.streaming.util | ||
|
|
||
| import java.io.{ObjectInputStream, ObjectOutputStream} | ||
| import java.io._ | ||
|
|
||
| import scala.reflect.ClassTag | ||
|
|
||
| import com.esotericsoftware.kryo.{Kryo, KryoSerializable} | ||
| import com.esotericsoftware.kryo.io.{Input, Output} | ||
|
|
||
| import org.apache.spark.SparkConf | ||
| import org.apache.spark.serializer.{KryoOutputObjectOutputBridge, KryoInputObjectInputBridge} | ||
| import org.apache.spark.streaming.util.OpenHashMapBasedStateMap._ | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I removed ClassTag because EmptyStateMap doesn't need it. If removing them, we don't need to add any codes for
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea. |
||
| import org.apache.spark.util.collection.OpenHashMap | ||
|
|
||
| /** Internal interface for defining the map that keeps track of sessions. */ | ||
| private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Serializable { | ||
| private[streaming] abstract class StateMap[K, S] extends Serializable { | ||
|
|
||
| /** Get the state for a key if it exists */ | ||
| def get(key: K): Option[S] | ||
|
|
@@ -54,7 +58,7 @@ private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Ser | |
|
|
||
| /** Companion object for [[StateMap]], with utility methods */ | ||
| private[streaming] object StateMap { | ||
| def empty[K: ClassTag, S: ClassTag]: StateMap[K, S] = new EmptyStateMap[K, S] | ||
| def empty[K, S]: StateMap[K, S] = new EmptyStateMap[K, S] | ||
|
|
||
| def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = { | ||
| val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold", | ||
|
|
@@ -64,7 +68,7 @@ private[streaming] object StateMap { | |
| } | ||
|
|
||
| /** Implementation of StateMap interface representing an empty map */ | ||
| private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] { | ||
| private[streaming] class EmptyStateMap[K, S] extends StateMap[K, S] { | ||
| override def put(key: K, session: S, updateTime: Long): Unit = { | ||
| throw new NotImplementedError("put() should not be called on an EmptyStateMap") | ||
| } | ||
|
|
@@ -77,21 +81,26 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa | |
| } | ||
|
|
||
| /** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */ | ||
| private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( | ||
| private[streaming] class OpenHashMapBasedStateMap[K, S]( | ||
| @transient @volatile var parentStateMap: StateMap[K, S], | ||
| initialCapacity: Int = DEFAULT_INITIAL_CAPACITY, | ||
| deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD | ||
| ) extends StateMap[K, S] { self => | ||
| private var initialCapacity: Int = DEFAULT_INITIAL_CAPACITY, | ||
| private var deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD | ||
| )(implicit private var keyClassTag: ClassTag[K], private var stateClassTag: ClassTag[S]) | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add |
||
| extends StateMap[K, S] with KryoSerializable { self => | ||
|
|
||
| def this(initialCapacity: Int, deltaChainThreshold: Int) = this( | ||
| def this(initialCapacity: Int, deltaChainThreshold: Int) | ||
| (implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = this( | ||
| new EmptyStateMap[K, S], | ||
| initialCapacity = initialCapacity, | ||
| deltaChainThreshold = deltaChainThreshold) | ||
|
|
||
| def this(deltaChainThreshold: Int) = this( | ||
| def this(deltaChainThreshold: Int) | ||
| (implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = this( | ||
| initialCapacity = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold = deltaChainThreshold) | ||
|
|
||
| def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD) | ||
| def this()(implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = { | ||
| this(DELTA_CHAIN_LENGTH_THRESHOLD) | ||
| } | ||
|
|
||
| require(initialCapacity >= 1, "Invalid initial capacity") | ||
| require(deltaChainThreshold >= 1, "Invalid delta chain threshold") | ||
|
|
@@ -206,11 +215,7 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( | |
| * Serialize the map data. Besides serialization, this method actually compact the deltas | ||
| * (if needed) in a single pass over all the data in the map. | ||
| */ | ||
|
|
||
| private def writeObject(outputStream: ObjectOutputStream): Unit = { | ||
| // Write all the non-transient fields, especially class tags, etc. | ||
| outputStream.defaultWriteObject() | ||
|
|
||
| private def writeObjectInternal(outputStream: ObjectOutput): Unit = { | ||
| // Write the data in the delta of this state map | ||
| outputStream.writeInt(deltaMap.size) | ||
| val deltaMapIterator = deltaMap.iterator | ||
|
|
@@ -262,11 +267,7 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( | |
| } | ||
|
|
||
| /** Deserialize the map data. */ | ||
| private def readObject(inputStream: ObjectInputStream): Unit = { | ||
|
|
||
| // Read the non-transient fields, especially class tags, etc. | ||
| inputStream.defaultReadObject() | ||
|
|
||
| private def readObjectInternal(inputStream: ObjectInput): Unit = { | ||
| // Read the data of the delta | ||
| val deltaMapSize = inputStream.readInt() | ||
| deltaMap = if (deltaMapSize != 0) { | ||
|
|
@@ -309,6 +310,34 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( | |
| } | ||
| parentStateMap = newParentSessionStore | ||
| } | ||
|
|
||
| private def writeObject(outputStream: ObjectOutputStream): Unit = { | ||
| // Write all the non-transient fields, especially class tags, etc. | ||
| outputStream.defaultWriteObject() | ||
| writeObjectInternal(outputStream) | ||
| } | ||
|
|
||
| private def readObject(inputStream: ObjectInputStream): Unit = { | ||
| // Read the non-transient fields, especially class tags, etc. | ||
| inputStream.defaultReadObject() | ||
| readObjectInternal(inputStream) | ||
| } | ||
|
|
||
| override def write(kryo: Kryo, output: Output): Unit = { | ||
| output.writeInt(initialCapacity) | ||
| output.writeInt(deltaChainThreshold) | ||
| kryo.writeClassAndObject(output, keyClassTag) | ||
| kryo.writeClassAndObject(output, stateClassTag) | ||
| writeObjectInternal(new KryoOutputObjectOutputBridge(kryo, output)) | ||
| } | ||
|
|
||
| override def read(kryo: Kryo, input: Input): Unit = { | ||
| initialCapacity = input.readInt() | ||
| deltaChainThreshold = input.readInt() | ||
| keyClassTag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[K]] | ||
| stateClassTag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[S]] | ||
| readObjectInternal(new KryoInputObjectInputBridge(kryo, input)) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,15 +17,23 @@ | |
|
|
||
| package org.apache.spark.streaming | ||
|
|
||
| import org.apache.spark.streaming.rdd.MapWithStateRDDRecord | ||
|
|
||
| import scala.collection.{immutable, mutable, Map} | ||
| import scala.reflect.ClassTag | ||
| import scala.util.Random | ||
|
|
||
| import org.apache.spark.SparkFunSuite | ||
| import com.esotericsoftware.kryo.{Kryo, KryoSerializable} | ||
| import com.esotericsoftware.kryo.io.{Output, Input} | ||
|
|
||
| import org.apache.spark.{SparkConf, SparkFunSuite} | ||
| import org.apache.spark.serializer._ | ||
| import org.apache.spark.streaming.util.{EmptyStateMap, OpenHashMapBasedStateMap, StateMap} | ||
| import org.apache.spark.util.Utils | ||
|
|
||
| class StateMapSuite extends SparkFunSuite { | ||
|
|
||
| private val conf = new SparkConf() | ||
|
|
||
| test("EmptyStateMap") { | ||
| val map = new EmptyStateMap[Int, Int] | ||
| intercept[scala.NotImplementedError] { | ||
|
|
@@ -128,17 +136,17 @@ class StateMapSuite extends SparkFunSuite { | |
| map1.put(2, 200, 2) | ||
| testSerialization(map1, "error deserializing and serialized map with data + no delta") | ||
|
|
||
| val map2 = map1.copy() | ||
| val map2 = map1.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]] | ||
| // Do not test compaction | ||
| assert(map2.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) | ||
| assert(map2.shouldCompact === false) | ||
| testSerialization(map2, "error deserializing and serialized map with 1 delta + no new data") | ||
|
|
||
| map2.put(3, 300, 3) | ||
| map2.put(4, 400, 4) | ||
| testSerialization(map2, "error deserializing and serialized map with 1 delta + new data") | ||
|
|
||
| val map3 = map2.copy() | ||
| assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) | ||
| val map3 = map2.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]] | ||
| assert(map3.shouldCompact === false) | ||
| testSerialization(map3, "error deserializing and serialized map with 2 delta + no new data") | ||
| map3.put(3, 600, 3) | ||
| map3.remove(2) | ||
|
|
@@ -267,18 +275,25 @@ class StateMapSuite extends SparkFunSuite { | |
| assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map") | ||
| } | ||
|
|
||
| private def testSerialization[MapType <: StateMap[Int, Int]]( | ||
| map: MapType, msg: String): MapType = { | ||
| val deserMap = Utils.deserialize[MapType]( | ||
| Utils.serialize(map), Thread.currentThread().getContextClassLoader) | ||
| private def testSerialization[T: ClassTag]( | ||
| map: OpenHashMapBasedStateMap[T, T], msg: String): OpenHashMapBasedStateMap[T, T] = { | ||
| testSerialization(new JavaSerializer(conf), map, msg) | ||
| testSerialization(new KryoSerializer(conf), map, msg) | ||
| } | ||
|
|
||
| private def testSerialization[T : ClassTag]( | ||
| serializer: Serializer, | ||
| map: OpenHashMapBasedStateMap[T, T], | ||
| msg: String): OpenHashMapBasedStateMap[T, T] = { | ||
| val deserMap = serializeAndDeserialize(serializer, map) | ||
| assertMap(deserMap, map, 1, msg) | ||
| deserMap | ||
| } | ||
|
|
||
| // Assert whether all the data and operations on a state map matches that of a reference state map | ||
| private def assertMap( | ||
| mapToTest: StateMap[Int, Int], | ||
| refMapToTestWith: StateMap[Int, Int], | ||
| private def assertMap[T]( | ||
| mapToTest: StateMap[T, T], | ||
| refMapToTestWith: StateMap[T, T], | ||
| time: Long, | ||
| msg: String): Unit = { | ||
| withClue(msg) { | ||
|
|
@@ -321,4 +336,59 @@ class StateMapSuite extends SparkFunSuite { | |
| } | ||
| } | ||
| } | ||
|
|
||
| test("OpenHashMapBasedStateMap - serializing and deserializing with KryoSerializable states") { | ||
| val map = new OpenHashMapBasedStateMap[KryoState, KryoState]() | ||
| map.put(new KryoState("a"), new KryoState("b"), 1) | ||
| testSerialization( | ||
| new KryoSerializer(conf), map, "error deserializing and serialized KryoSerializable states") | ||
| } | ||
|
|
||
| test("EmptyStateMap - serializing and deserializing") { | ||
| val map = StateMap.empty[KryoState, KryoState] | ||
| // Since EmptyStateMap doesn't contains any date, KryoState won't break JavaSerializer. | ||
| assert(serializeAndDeserialize(new JavaSerializer(conf), map). | ||
| isInstanceOf[EmptyStateMap[KryoState, KryoState]]) | ||
| assert(serializeAndDeserialize(new KryoSerializer(conf), map). | ||
| isInstanceOf[EmptyStateMap[KryoState, KryoState]]) | ||
| } | ||
|
|
||
| test("MapWithStateRDDRecord - serializing and deserializing with KryoSerializable states") { | ||
| val map = new OpenHashMapBasedStateMap[KryoState, KryoState]() | ||
| map.put(new KryoState("a"), new KryoState("b"), 1) | ||
|
|
||
| val record = | ||
| MapWithStateRDDRecord[KryoState, KryoState, KryoState](map, Seq(new KryoState("c"))) | ||
| val deserRecord = serializeAndDeserialize(new KryoSerializer(conf), record) | ||
| assert(!(record eq deserRecord)) | ||
| assert(record.stateMap.getAll().toSeq === deserRecord.stateMap.getAll().toSeq) | ||
| assert(record.mappedData === deserRecord.mappedData) | ||
| } | ||
|
|
||
| private def serializeAndDeserialize[T: ClassTag](serializer: Serializer, t: T): T = { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this need the ClassTag
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. |
||
| val serializerInstance = serializer.newInstance() | ||
| serializerInstance.deserialize[T]( | ||
| serializerInstance.serialize(t), Thread.currentThread().getContextClassLoader) | ||
| } | ||
| } | ||
|
|
||
| /** A class that only supports Kryo serialization. */ | ||
| private[streaming] final class KryoState(var state: String) extends KryoSerializable { | ||
|
|
||
| override def write(kryo: Kryo, output: Output): Unit = { | ||
| kryo.writeClassAndObject(output, state) | ||
| } | ||
|
|
||
| override def read(kryo: Kryo, input: Input): Unit = { | ||
| state = kryo.readClassAndObject(input).asInstanceOf[String] | ||
| } | ||
|
|
||
| override def equals(other: Any): Boolean = other match { | ||
| case that: KryoState => state == that.state | ||
| case _ => false | ||
| } | ||
|
|
||
| override def hashCode(): Int = { | ||
| if (state == null) 0 else state.hashCode() | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you put some docs on this class to explain what this does? Same for the above class.