
Commit abcafcf

zsxwing authored and rxin committed
[SPARK-3922] Refactor spark-core to use Utils.UTF_8
A global UTF8 constant is very helpful to handle encoding problems when converting between String and bytes. There are several solutions here:

1. Add `val UTF_8 = Charset.forName("UTF-8")` to Utils.scala
2. java.nio.charset.StandardCharsets.UTF_8 (requires JDK7)
3. io.netty.util.CharsetUtil.UTF_8
4. com.google.common.base.Charsets.UTF_8
5. org.apache.commons.lang.CharEncoding.UTF_8
6. org.apache.commons.lang3.CharEncoding.UTF_8

IMO, I prefer option 1) because people can find it easily. This is a PR for option 1) and only fixes Spark Core.

Author: zsxwing <[email protected]>

Closes #2781 from zsxwing/SPARK-3922 and squashes the following commits:

f974edd [zsxwing] Merge branch 'master' into SPARK-3922
2d27423 [zsxwing] Refactor spark-core to use Utils.UTF_8
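For reference, a minimal sketch of what option 1 could look like (hypothetical placement; the file diffs shown below actually import Guava's com.google.common.base.Charsets.UTF_8 directly rather than going through a Utils constant):

    package org.apache.spark.util

    import java.nio.charset.Charset

    private[spark] object Utils {
      // Option 1: one shared constant, resolved once, that call sites can import.
      val UTF_8: Charset = Charset.forName("UTF-8")
    }

    // A call site then passes the Charset instead of a string name:
    //   str.getBytes("utf-8")       // charset looked up by name on every call; a typo only fails at runtime
    //   str.getBytes(Utils.UTF_8)   // no name lookup, no UnsupportedEncodingException to worry about
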
1 parent 47a40f6 · commit abcafcf

File tree

16 files changed (+55, -46 lines)


core/src/main/scala/org/apache/spark/SparkSaslClient.scala

Lines changed: 4 additions & 3 deletions
@@ -17,7 +17,6 @@
 
 package org.apache.spark
 
-import java.io.IOException
 import javax.security.auth.callback.Callback
 import javax.security.auth.callback.CallbackHandler
 import javax.security.auth.callback.NameCallback
@@ -31,6 +30,8 @@ import javax.security.sasl.SaslException
 
 import scala.collection.JavaConversions.mapAsJavaMap
 
+import com.google.common.base.Charsets.UTF_8
+
 /**
  * Implements SASL Client logic for Spark
  */
@@ -111,10 +112,10 @@ private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logg
   CallbackHandler {
 
   private val userName: String =
-    SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes("utf-8"))
+    SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8))
   private val secretKey = securityMgr.getSecretKey()
   private val userPassword: Array[Char] = SparkSaslServer.encodePassword(
-    if (secretKey != null) secretKey.getBytes("utf-8") else "".getBytes("utf-8"))
+    if (secretKey != null) secretKey.getBytes(UTF_8) else "".getBytes(UTF_8))
 
   /**
    * Implementation used to respond to SASL request from the server.

core/src/main/scala/org/apache/spark/SparkSaslServer.scala

Lines changed: 6 additions & 4 deletions
@@ -28,6 +28,8 @@ import javax.security.sasl.Sasl
 import javax.security.sasl.SaslException
 import javax.security.sasl.SaslServer
 import scala.collection.JavaConversions.mapAsJavaMap
+
+import com.google.common.base.Charsets.UTF_8
 import org.apache.commons.net.util.Base64
 
 /**
@@ -89,7 +91,7 @@ private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Loggi
   extends CallbackHandler {
 
   private val userName: String =
-    SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes("utf-8"))
+    SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8))
 
   override def handle(callbacks: Array[Callback]) {
     logDebug("In the sasl server callback handler")
@@ -101,7 +103,7 @@ private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Loggi
       case pc: PasswordCallback => {
         logDebug("handle: SASL server callback: setting userPassword")
         val password: Array[Char] =
-          SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes("utf-8"))
+          SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes(UTF_8))
         pc.setPassword(password)
       }
       case rc: RealmCallback => {
@@ -159,7 +161,7 @@ private[spark] object SparkSaslServer {
   * @return Base64-encoded string
   */
  def encodeIdentifier(identifier: Array[Byte]): String = {
-    new String(Base64.encodeBase64(identifier), "utf-8")
+    new String(Base64.encodeBase64(identifier), UTF_8)
  }
 
  /**
@@ -168,7 +170,7 @@ private[spark] object SparkSaslServer {
   * @return password as a char array.
   */
  def encodePassword(password: Array[Byte]): Array[Char] = {
-    new String(Base64.encodeBase64(password), "utf-8").toCharArray()
+    new String(Base64.encodeBase64(password), UTF_8).toCharArray()
  }
 }
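Editor's note: the encodeIdentifier/encodePassword helpers above now decode the Base64 output with an explicit charset. A self-contained sketch of the same round trip, using java.util.Base64 (JDK 8) instead of the Commons Net Base64 used in the file, purely for illustration:

    import java.util.Base64
    import com.google.common.base.Charsets.UTF_8

    object SaslEncodingSketch {
      // Mirrors encodeIdentifier: raw bytes -> Base64 -> String. Base64 output is pure
      // ASCII, so decoding it as UTF-8 is safe on any JVM.
      def encodeIdentifier(identifier: Array[Byte]): String =
        new String(Base64.getEncoder.encode(identifier), UTF_8)

      def main(args: Array[String]): Unit = {
        println(encodeIdentifier("sparkSaslUser".getBytes(UTF_8))) // c3BhcmtTYXNsVXNlcg==
      }
    }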

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 4 additions & 5 deletions
@@ -19,14 +19,14 @@ package org.apache.spark.api.python
 
 import java.io._
 import java.net._
-import java.nio.charset.Charset
 import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
 
 import scala.collection.JavaConversions._
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.language.existentials
 
+import com.google.common.base.Charsets.UTF_8
 import net.razorvine.pickle.{Pickler, Unpickler}
 
 import org.apache.hadoop.conf.Configuration
@@ -134,7 +134,7 @@ private[spark] class PythonRDD(
           val exLength = stream.readInt()
           val obj = new Array[Byte](exLength)
           stream.readFully(obj)
-          throw new PythonException(new String(obj, "utf-8"),
+          throw new PythonException(new String(obj, UTF_8),
             writerThread.exception.getOrElse(null))
         case SpecialLengths.END_OF_DATA_SECTION =>
           // We've finished the data section of the output, but we can still
@@ -318,7 +318,6 @@ private object SpecialLengths {
 }
 
 private[spark] object PythonRDD extends Logging {
-  val UTF8 = Charset.forName("UTF-8")
 
   // remember the broadcasts sent to each worker
   private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
@@ -586,7 +585,7 @@ private[spark] object PythonRDD extends Logging {
   }
 
   def writeUTF(str: String, dataOut: DataOutputStream) {
-    val bytes = str.getBytes(UTF8)
+    val bytes = str.getBytes(UTF_8)
     dataOut.writeInt(bytes.length)
     dataOut.write(bytes)
   }
@@ -849,7 +848,7 @@ private[spark] object PythonRDD extends Logging {
 
 private
 class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
-  override def call(arr: Array[Byte]) : String = new String(arr, PythonRDD.UTF8)
+  override def call(arr: Array[Byte]) : String = new String(arr, UTF_8)
 }
 
 /**
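Editor's note: writeUTF above frames a string as a 4-byte length followed by its UTF-8 bytes. A hedged sketch of a matching reader, shown here only to illustrate the framing (the readUTF name and the round-trip test are hypothetical, not part of this commit):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
    import com.google.common.base.Charsets.UTF_8

    object WriteUTFSketch {
      // Same framing as PythonRDD.writeUTF: 4-byte big-endian length, then UTF-8 bytes.
      def writeUTF(str: String, dataOut: DataOutputStream): Unit = {
        val bytes = str.getBytes(UTF_8)
        dataOut.writeInt(bytes.length)
        dataOut.write(bytes)
      }

      // Hypothetical counterpart that reads one length-prefixed string back.
      def readUTF(dataIn: DataInputStream): String = {
        val len = dataIn.readInt()
        val bytes = new Array[Byte](len)
        dataIn.readFully(bytes)
        new String(bytes, UTF_8)
      }

      def main(args: Array[String]): Unit = {
        val buf = new ByteArrayOutputStream()
        writeUTF("héllo", new DataOutputStream(buf))
        val in = new DataInputStream(new ByteArrayInputStream(buf.toByteArray))
        assert(readUTF(in) == "héllo")
      }
    }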

core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala

Lines changed: 3 additions & 2 deletions
@@ -18,7 +18,8 @@
 package org.apache.spark.api.python
 
 import java.io.{DataOutput, DataInput}
-import java.nio.charset.Charset
+
+import com.google.common.base.Charsets.UTF_8
 
 import org.apache.hadoop.io._
 import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat
@@ -136,7 +137,7 @@ object WriteInputFormatTestDataGenerator {
     sc.parallelize(intKeys).saveAsSequenceFile(intPath)
     sc.parallelize(intKeys.map{ case (k, v) => (k.toDouble, v) }).saveAsSequenceFile(doublePath)
     sc.parallelize(intKeys.map{ case (k, v) => (k.toString, v) }).saveAsSequenceFile(textPath)
-    sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(Charset.forName("UTF-8"))) }
+    sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(UTF_8)) }
     ).saveAsSequenceFile(bytesPath)
     val bools = Seq((1, true), (2, true), (2, false), (3, true), (2, false), (1, false))
     sc.parallelize(bools).saveAsSequenceFile(boolPath)

core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@ import scala.collection.JavaConversions._
 import scala.collection.Map
 
 import akka.actor.ActorRef
-import com.google.common.base.Charsets
+import com.google.common.base.Charsets.UTF_8
 import com.google.common.io.Files
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileUtil, Path}
@@ -178,7 +178,7 @@ private[spark] class DriverRunner(
       val stderr = new File(baseDir, "stderr")
       val header = "Launch Command: %s\n%s\n\n".format(
         command.mkString("\"", "\" \"", "\""), "=" * 40)
-      Files.append(header, stderr, Charsets.UTF_8)
+      Files.append(header, stderr, UTF_8)
       CommandUtils.redirectStream(process.getErrorStream, stderr)
     }
     runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise)

core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala

Lines changed: 2 additions & 2 deletions
@@ -20,7 +20,7 @@ package org.apache.spark.deploy.worker
 import java.io._
 
 import akka.actor.ActorRef
-import com.google.common.base.Charsets
+import com.google.common.base.Charsets.UTF_8
 import com.google.common.io.Files
 
 import org.apache.spark.{SparkConf, Logging}
@@ -151,7 +151,7 @@ private[spark] class ExecutorRunner(
     stdoutAppender = FileAppender(process.getInputStream, stdout, conf)
 
     val stderr = new File(executorDir, "stderr")
-    Files.write(header, stderr, Charsets.UTF_8)
+    Files.write(header, stderr, UTF_8)
     stderrAppender = FileAppender(process.getErrorStream, stderr, conf)
 
     state = ExecutorState.RUNNING

core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala

Lines changed: 2 additions & 2 deletions
@@ -19,13 +19,13 @@ package org.apache.spark.network.netty.client
 
 import java.util.concurrent.TimeoutException
 
+import com.google.common.base.Charsets.UTF_8
 import io.netty.bootstrap.Bootstrap
 import io.netty.buffer.PooledByteBufAllocator
 import io.netty.channel.socket.SocketChannel
 import io.netty.channel.{ChannelFutureListener, ChannelFuture, ChannelInitializer, ChannelOption}
 import io.netty.handler.codec.LengthFieldBasedFrameDecoder
 import io.netty.handler.codec.string.StringEncoder
-import io.netty.util.CharsetUtil
 
 import org.apache.spark.Logging
 
@@ -61,7 +61,7 @@ class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String,
     b.handler(new ChannelInitializer[SocketChannel] {
       override def initChannel(ch: SocketChannel): Unit = {
         ch.pipeline
-          .addLast("encoder", new StringEncoder(CharsetUtil.UTF_8))
+          .addLast("encoder", new StringEncoder(UTF_8))
           // maxFrameLength = 2G, lengthFieldOffset = 0, lengthFieldLength = 4
           .addLast("framedLengthDecoder", new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 4))
           .addLast("handler", handler)

core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala

Lines changed: 3 additions & 2 deletions
@@ -17,6 +17,7 @@
 
 package org.apache.spark.network.netty.client
 
+import com.google.common.base.Charsets.UTF_8
 import io.netty.buffer.ByteBuf
 import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
 
@@ -67,7 +68,7 @@ class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] wi
     val blockIdLen = in.readInt()
     val blockIdBytes = new Array[Byte](math.abs(blockIdLen))
     in.readBytes(blockIdBytes)
-    val blockId = new String(blockIdBytes)
+    val blockId = new String(blockIdBytes, UTF_8)
     val blockSize = totalLen - math.abs(blockIdLen) - 4
 
     def server = ctx.channel.remoteAddress.toString
@@ -76,7 +77,7 @@ class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] wi
     if (blockIdLen < 0) {
       val errorMessageBytes = new Array[Byte](blockSize)
       in.readBytes(errorMessageBytes)
-      val errorMsg = new String(errorMessageBytes)
+      val errorMsg = new String(errorMessageBytes, UTF_8)
       logTrace(s"Received block $blockId ($blockSize B) with error $errorMsg from $server")
 
       val listener = outstandingRequests.get(blockId)
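Editor's note: unlike most changes in this commit, the two hunks above change behavior rather than style: new String(bytes) decodes with the JVM's platform default charset, which can differ between machines. A small illustrative sketch (hypothetical names, not part of the diff):

    import com.google.common.base.Charsets.UTF_8

    object DefaultCharsetSketch {
      def main(args: Array[String]): Unit = {
        val errorMessageBytes = "Fichier non trouvé".getBytes(UTF_8)

        // Explicit charset: identical result on every JVM.
        val explicit = new String(errorMessageBytes, UTF_8)

        // Platform default (what the old code did): depends on the OS locale /
        // -Dfile.encoding, e.g. mangled when the default is ISO-8859-1.
        val platformDependent = new String(errorMessageBytes)

        println(s"default charset = ${java.nio.charset.Charset.defaultCharset()}")
        println(s"explicit = $explicit, platform-dependent = $platformDependent")
      }
    }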

core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala

Lines changed: 2 additions & 2 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.network.netty.server
 
 import java.net.InetSocketAddress
 
+import com.google.common.base.Charsets.UTF_8
 import io.netty.bootstrap.ServerBootstrap
 import io.netty.buffer.PooledByteBufAllocator
 import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption}
@@ -30,7 +31,6 @@ import io.netty.channel.socket.nio.NioServerSocketChannel
 import io.netty.channel.socket.oio.OioServerSocketChannel
 import io.netty.handler.codec.LineBasedFrameDecoder
 import io.netty.handler.codec.string.StringDecoder
-import io.netty.util.CharsetUtil
 
 import org.apache.spark.{Logging, SparkConf}
 import org.apache.spark.network.netty.NettyConfig
@@ -131,7 +131,7 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Lo
       override def initChannel(ch: SocketChannel): Unit = {
         ch.pipeline
           .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024
-          .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8))
+          .addLast("stringDecoder", new StringDecoder(UTF_8))
           .addLast("blockHeaderEncoder", new BlockHeaderEncoder)
           .addLast("handler", new BlockServerHandler(dataProvider))
       }

core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala

Lines changed: 3 additions & 3 deletions
@@ -17,13 +17,13 @@
 
 package org.apache.spark.network.netty.server
 
+import com.google.common.base.Charsets.UTF_8
 import io.netty.channel.ChannelInitializer
 import io.netty.channel.socket.SocketChannel
 import io.netty.handler.codec.LineBasedFrameDecoder
 import io.netty.handler.codec.string.StringDecoder
-import io.netty.util.CharsetUtil
-import org.apache.spark.storage.BlockDataProvider
 
+import org.apache.spark.storage.BlockDataProvider
 
 /** Channel initializer that sets up the pipeline for the BlockServer. */
 private[netty]
@@ -33,7 +33,7 @@ class BlockServerChannelInitializer(dataProvider: BlockDataProvider)
   override def initChannel(ch: SocketChannel): Unit = {
     ch.pipeline
       .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024
-      .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8))
+      .addLast("stringDecoder", new StringDecoder(UTF_8))
      .addLast("blockHeaderEncoder", new BlockHeaderEncoder)
      .addLast("handler", new BlockServerHandler(dataProvider))
   }
