From 8bb3bf89fa07d98cf70a4088e61afedacd340022 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 11 Apr 2023 11:24:16 +0800 Subject: [PATCH 1/5] init init init init init --- .../spark/sql/connect/common/ProtoUtils.scala | 54 +++++++++++++++++++ .../service/SparkConnectStreamHandler.scala | 24 ++++----- 2 files changed, 64 insertions(+), 14 deletions(-) create mode 100644 connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala new file mode 100644 index 0000000000000..37a3bada082cc --- /dev/null +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.common + +import scala.collection.JavaConverters._ + +import com.google.protobuf.{ByteString, Message} +import com.google.protobuf.Descriptors.FieldDescriptor + +private[connect] object ProtoUtils { + private val format = java.text.NumberFormat.getInstance() + + def redact(message: Message): Message = { + val builder = message.toBuilder + + message.getAllFields.asScala.iterator.foreach { + case (field: FieldDescriptor, bytes: ByteString) + if field.getJavaType == FieldDescriptor.JavaType.BYTE_STRING && bytes != null => + builder.setField( + field, + ByteString.copyFromUtf8(s"bytes (size=${format.format(bytes.size)})")) + + case (field: FieldDescriptor, bytes: Array[_]) + if field.getJavaType == FieldDescriptor.JavaType.BYTE_STRING && bytes != null => + builder.setField( + field, + ByteString.copyFromUtf8(s"bytes (size=${format.format(bytes.length)})")) + + // TODO: should also take 1, repeated msg; 2, map into account + case (field: FieldDescriptor, msg: Message) + if field.getJavaType == FieldDescriptor.JavaType.MESSAGE && msg != null => + builder.setField(field, redact(msg)) + + case (field: FieldDescriptor, value: Any) => builder.setField(field, value) + } + + builder.build() + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 6462976aebb81..212344d7a409a 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import 
org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager -import org.apache.spark.sql.connect.common.DataTypeProtoConverter +import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ProtoUtils} import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE import org.apache.spark.sql.connect.planner.SparkConnectPlanner @@ -54,19 +54,15 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp session.withActive { // Add debug information to the query execution so that the jobs are traceable. - try { - val debugString = - Utils.redact(session.sessionState.conf.stringRedactionPattern, v.toString) - session.sparkContext.setLocalProperty( - "callSite.short", - s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}") - session.sparkContext.setLocalProperty( - "callSite.long", - StringUtils.abbreviate(debugString, 2048)) - } catch { - case e: Throwable => - logWarning("Fail to extract or attach the debug information", e) - } + val debugString = Utils.redact( + session.sessionState.conf.stringRedactionPattern, + ProtoUtils.redact(v).toString) + session.sparkContext.setLocalProperty( + "callSite.short", + s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}") + session.sparkContext.setLocalProperty( + "callSite.long", + StringUtils.abbreviate(debugString, 2048)) v.getPlan.getOpTypeCase match { case proto.Plan.OpTypeCase.COMMAND => handleCommand(session, v) From 326d78327e68c7e5b8b99ef79bf0fc9c2cc90285 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 12 Apr 2023 16:53:44 +0800 Subject: [PATCH 2/5] address comments --- .../spark/sql/connect/common/ProtoUtils.scala | 52 +++++++++++++------ .../service/SparkConnectStreamHandler.scala | 2 +- 2 files changed, 38 insertions(+), 16 deletions(-) diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala index 37a3bada082cc..1a662db2aace2 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala @@ -24,31 +24,53 @@ import com.google.protobuf.Descriptors.FieldDescriptor private[connect] object ProtoUtils { private val format = java.text.NumberFormat.getInstance() + private val NUM_FIRST_BYTES = 8 - def redact(message: Message): Message = { + def abbreviateBytes(message: Message): Message = { val builder = message.toBuilder message.getAllFields.asScala.iterator.foreach { - case (field: FieldDescriptor, bytes: ByteString) - if field.getJavaType == FieldDescriptor.JavaType.BYTE_STRING && bytes != null => - builder.setField( - field, - ByteString.copyFromUtf8(s"bytes (size=${format.format(bytes.size)})")) - - case (field: FieldDescriptor, bytes: Array[_]) - if field.getJavaType == FieldDescriptor.JavaType.BYTE_STRING && bytes != null => - builder.setField( - field, - ByteString.copyFromUtf8(s"bytes (size=${format.format(bytes.length)})")) - - // TODO: should also take 1, repeated msg; 2, map into account + case (field: FieldDescriptor, byteString: ByteString) + if field.getJavaType == FieldDescriptor.JavaType.BYTE_STRING && byteString != null => + val size = byteString.size() + if (size > NUM_FIRST_BYTES) { + val bytes = Array.ofDim[Byte](NUM_FIRST_BYTES) + 
var i = 0 + while (i < NUM_FIRST_BYTES) { + bytes(i) = byteString.byteAt(i) + i += 1 + } + builder.setField(field, createByteString(Some(bytes), size)) + } else { + builder.setField(field, createByteString(None, size)) + } + + case (field: FieldDescriptor, byteArray: Array[Byte]) + if field.getJavaType == FieldDescriptor.JavaType.BYTE_STRING && byteArray != null => + val size = byteArray.length + if (size > NUM_FIRST_BYTES) { + builder.setField(field, createByteString(Some(byteArray.take(NUM_FIRST_BYTES)), size)) + } else { + builder.setField(field, createByteString(None, size)) + } + + // TODO: should also support 1, repeated msg; 2, map case (field: FieldDescriptor, msg: Message) if field.getJavaType == FieldDescriptor.JavaType.MESSAGE && msg != null => - builder.setField(field, redact(msg)) + builder.setField(field, abbreviateBytes(msg)) case (field: FieldDescriptor, value: Any) => builder.setField(field, value) } builder.build() } + + private def createByteString(firstBytes: Option[Array[Byte]], size: Int): ByteString = { + var byteStrings = Array.empty[ByteString] + firstBytes.foreach { bytes => + byteStrings :+= ByteString.copyFrom(bytes) + } + byteStrings :+= ByteString.copyFromUtf8(s"*********(redacted, size=${format.format(size)})") + ByteString.copyFrom(byteStrings.toIterable.asJava) + } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 212344d7a409a..88191ca28c3cf 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -56,7 +56,7 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp // Add debug information to the query execution so that the jobs are traceable. 
val debugString = Utils.redact( session.sessionState.conf.stringRedactionPattern, - ProtoUtils.redact(v).toString) + ProtoUtils.abbreviateBytes(v).toString) session.sparkContext.setLocalProperty( "callSite.short", s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}") From ef7088521f7abca0b5d7335f99e8393cad4cec98 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 12 Apr 2023 19:32:23 +0800 Subject: [PATCH 3/5] address comments --- .../spark/sql/connect/common/ProtoUtils.scala | 56 +++++++++++-------- .../service/SparkConnectStreamHandler.scala | 24 +++++--- 2 files changed, 47 insertions(+), 33 deletions(-) diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala index 1a662db2aace2..fade524ccac55 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala @@ -24,40 +24,46 @@ import com.google.protobuf.Descriptors.FieldDescriptor private[connect] object ProtoUtils { private val format = java.text.NumberFormat.getInstance() - private val NUM_FIRST_BYTES = 8 + private val MAX_BYTES_SIZE = 8 + private val MAX_STRING_SIZE = 1024 - def abbreviateBytes(message: Message): Message = { + def abbreviate(message: Message): Message = { val builder = message.toBuilder message.getAllFields.asScala.iterator.foreach { + case (field: FieldDescriptor, string: String) + if field.getJavaType == FieldDescriptor.JavaType.STRING && string != null => + val size = string.size + if (size > MAX_STRING_SIZE) { + builder.setField(field, createString(string.take(MAX_STRING_SIZE), size)) + } else { + builder.setField(field, string) + } + case (field: FieldDescriptor, byteString: ByteString) if field.getJavaType == FieldDescriptor.JavaType.BYTE_STRING && byteString != null => - val size = byteString.size() - if (size > NUM_FIRST_BYTES) { - val bytes = Array.ofDim[Byte](NUM_FIRST_BYTES) - var i = 0 - while (i < NUM_FIRST_BYTES) { - bytes(i) = byteString.byteAt(i) - i += 1 - } - builder.setField(field, createByteString(Some(bytes), size)) + val size = byteString.size + if (size > MAX_BYTES_SIZE) { + val prefix = Array.tabulate(MAX_BYTES_SIZE)(byteString.byteAt) + builder.setField(field, createByteString(prefix, size)) } else { - builder.setField(field, createByteString(None, size)) + builder.setField(field, byteString) } case (field: FieldDescriptor, byteArray: Array[Byte]) if field.getJavaType == FieldDescriptor.JavaType.BYTE_STRING && byteArray != null => - val size = byteArray.length - if (size > NUM_FIRST_BYTES) { - builder.setField(field, createByteString(Some(byteArray.take(NUM_FIRST_BYTES)), size)) + val size = byteArray.size + if (size > MAX_BYTES_SIZE) { + val prefix = byteArray.take(MAX_BYTES_SIZE) + builder.setField(field, createByteString(prefix, size)) } else { - builder.setField(field, createByteString(None, size)) + builder.setField(field, byteArray) } // TODO: should also support 1, repeated msg; 2, map case (field: FieldDescriptor, msg: Message) if field.getJavaType == FieldDescriptor.JavaType.MESSAGE && msg != null => - builder.setField(field, abbreviateBytes(msg)) + builder.setField(field, abbreviate(msg)) case (field: FieldDescriptor, value: Any) => builder.setField(field, value) } @@ -65,12 +71,14 @@ private[connect] object ProtoUtils { builder.build() } - private def createByteString(firstBytes: 
Option[Array[Byte]], size: Int): ByteString = { - var byteStrings = Array.empty[ByteString] - firstBytes.foreach { bytes => - byteStrings :+= ByteString.copyFrom(bytes) - } - byteStrings :+= ByteString.copyFromUtf8(s"*********(redacted, size=${format.format(size)})") - ByteString.copyFrom(byteStrings.toIterable.asJava) + private def createByteString(prefix: Array[Byte], size: Int): ByteString = { + ByteString.copyFrom( + List( + ByteString.copyFrom(prefix), + ByteString.copyFromUtf8(s"*********(redacted, size=${format.format(size)})")).asJava) + } + + private def createString(prefix: String, size: Int): String = { + s"$prefix*********(redacted, size=${format.format(size)})" } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 88191ca28c3cf..8e13b4d14f77f 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -54,15 +54,21 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp session.withActive { // Add debug information to the query execution so that the jobs are traceable. - val debugString = Utils.redact( - session.sessionState.conf.stringRedactionPattern, - ProtoUtils.abbreviateBytes(v).toString) - session.sparkContext.setLocalProperty( - "callSite.short", - s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}") - session.sparkContext.setLocalProperty( - "callSite.long", - StringUtils.abbreviate(debugString, 2048)) + try { + val debugString = + Utils.redact( + session.sessionState.conf.stringRedactionPattern, + ProtoUtils.abbreviate(v).toString) + session.sparkContext.setLocalProperty( + "callSite.short", + s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}") + session.sparkContext.setLocalProperty( + "callSite.long", + StringUtils.abbreviate(debugString, 2048)) + } catch { + case e: Throwable => + logWarning("Fail to extract or attach the debug information", e) + } v.getPlan.getOpTypeCase match { case proto.Plan.OpTypeCase.COMMAND => handleCommand(session, v) From ff40fbf7dd8a5ae9a3cbdb5d6d0450fdcbd4ebda Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 13 Apr 2023 08:29:32 +0800 Subject: [PATCH 4/5] address comments --- .../org/apache/spark/sql/connect/common/ProtoUtils.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala index fade524ccac55..0ef92b84754f9 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala @@ -75,10 +75,10 @@ private[connect] object ProtoUtils { ByteString.copyFrom( List( ByteString.copyFrom(prefix), - ByteString.copyFromUtf8(s"*********(redacted, size=${format.format(size)})")).asJava) + ByteString.copyFromUtf8(s"[truncated(size=${format.format(size)})]")).asJava) } private def createString(prefix: String, size: Int): String = { - s"$prefix*********(redacted, size=${format.format(size)})" + s"$prefix[truncated(size=${format.format(size)})]" } } From 
17986143accc9d6542f6365fb6b9ca1528bcbd6e Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 13 Apr 2023 11:21:25 +0800 Subject: [PATCH 5/5] add a ticket --- .../scala/org/apache/spark/sql/connect/common/ProtoUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala index 0ef92b84754f9..83f84f45b317e 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala @@ -60,7 +60,7 @@ private[connect] object ProtoUtils { builder.setField(field, byteArray) } - // TODO: should also support 1, repeated msg; 2, map + // TODO(SPARK-43117): should also support 1, repeated msg; 2, map case (field: FieldDescriptor, msg: Message) if field.getJavaType == FieldDescriptor.JavaType.MESSAGE && msg != null => builder.setField(field, abbreviate(msg))
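
Reviewer note: below is a minimal, self-contained sketch of the truncation rule this series converges on. The prefix-plus-marker construction is lifted from the final ProtoUtils.createByteString / createString (after PATCH 4/5 renames the marker to "[truncated(size=...)]"); the AbbreviateDemo object, the fake 1 KiB payload, and the printlns are hypothetical scaffolding for illustration only and are not part of the patch.

import com.google.protobuf.ByteString

object AbbreviateDemo {
  private val format = java.text.NumberFormat.getInstance()
  private val MAX_BYTES_SIZE = 8     // same thresholds as ProtoUtils
  private val MAX_STRING_SIZE = 1024

  def main(args: Array[String]): Unit = {
    // A 1 KiB binary payload, standing in for e.g. a serialized plan blob.
    val payload = ByteString.copyFrom(Array.fill[Byte](1024)(42.toByte))

    // Keep only the first MAX_BYTES_SIZE bytes, then append a readable
    // size marker -- this mirrors ProtoUtils.createByteString.
    val prefix = Array.tabulate(MAX_BYTES_SIZE)(payload.byteAt)
    val abbreviated = ByteString.copyFrom(
      java.util.Arrays.asList(
        ByteString.copyFrom(prefix),
        ByteString.copyFromUtf8(s"[truncated(size=${format.format(payload.size)})]")))

    // Strings longer than MAX_STRING_SIZE get the same treatment via
    // ProtoUtils.createString: a prefix followed by the size marker.
    val longString = "x" * 2000
    val abbreviatedString =
      s"${longString.take(MAX_STRING_SIZE)}[truncated(size=${format.format(longString.length)})]"

    println(abbreviated.size)         // a few dozen bytes instead of 1,024
    println(abbreviatedString.length) // ~1,050 chars instead of 2,000 (marker length varies with locale)
  }
}

The abbreviation is lossy by design: it bounds the size of the callSite.short / callSite.long debug strings without touching the plan that is actually executed, and the try/catch restored in PATCH 3/5 ensures that any failure while building the debug string is logged rather than failing the query.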