UnsafeSorterSpillReader.java
@@ -51,7 +51,6 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
 
   private byte[] arr = new byte[1024 * 1024];
   private Object baseObject = arr;
-  private final long baseOffset = Platform.BYTE_ARRAY_OFFSET;
   private final TaskContext taskContext = TaskContext.get();
 
   public UnsafeSorterSpillReader(
@@ -132,7 +131,7 @@ public Object getBaseObject() {
 
   @Override
   public long getBaseOffset() {
-    return baseOffset;
+    return Platform.BYTE_ARRAY_OFFSET;
   }
 
   @Override
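
Note on this first file: Platform.BYTE_ARRAY_OFFSET is a per-JVM static (the base offset of byte[] contents reported by sun.misc.Unsafe), so within a single JVM the removed field and the direct read are interchangeable; the cached copy was simply redundant. A minimal sketch to print the value for the current JVM, assuming spark-unsafe is on the classpath (OffsetProbe is an illustrative name, not part of this patch); on HotSpot the value is typically 16 with -XX:+UseCompressedOops and 24 with -XX:-UseCompressedOops, which is exactly why a driver-side constant must not leak into code that runs on executors:

import org.apache.spark.unsafe.Platform

object OffsetProbe {
  def main(args: Array[String]): Unit = {
    // Prints this JVM's byte[] base offset; it differs between JVMs started
    // with different object-layout flags (e.g. UseCompressedOops on vs. off).
    println(s"byte[] base offset: ${Platform.BYTE_ARRAY_OFFSET}")
  }
}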
GenerateUnsafeRowJoiner.scala
@@ -55,7 +55,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
 
   def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = {
     val ctx = new CodegenContext
-    val offset = Platform.BYTE_ARRAY_OFFSET
+    val offset = "Platform.BYTE_ARRAY_OFFSET"
     val getLong = "Platform.getLong"
     val putLong = "Platform.putLong"
 
@@ -92,7 +92,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
           s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8})"
         }
       }
-      s"$putLong(buf, ${offset + i * 8}, $bits);\n"
+      s"$putLong(buf, $offset + ${i * 8}, $bits);\n"
     }
 
     val copyBitsets = ctx.splitExpressions(
@@ -102,12 +102,12 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
       ("java.lang.Object", "obj2") :: ("long", "offset2") :: Nil)
 
     // --------------------- copy fixed length portion from row 1 ----------------------- //
-    var cursor = offset + outputBitsetWords * 8
+    var cursor = outputBitsetWords * 8
     val copyFixedLengthRow1 = s"""
       |// Copy fixed length data for row1
       |Platform.copyMemory(
       |  obj1, offset1 + ${bitset1Words * 8},
-      |  buf, $cursor,
+      |  buf, $offset + $cursor,
       |  ${schema1.size * 8});
     """.stripMargin
     cursor += schema1.size * 8
@@ -117,7 +117,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
       |// Copy fixed length data for row2
       |Platform.copyMemory(
       |  obj2, offset2 + ${bitset2Words * 8},
-      |  buf, $cursor,
+      |  buf, $offset + $cursor,
       |  ${schema2.size * 8});
     """.stripMargin
     cursor += schema2.size * 8
@@ -129,7 +129,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
       |long numBytesVariableRow1 = row1.getSizeInBytes() - $numBytesBitsetAndFixedRow1;
       |Platform.copyMemory(
       |  obj1, offset1 + ${(bitset1Words + schema1.size) * 8},
-      |  buf, $cursor,
+      |  buf, $offset + $cursor,
       |  numBytesVariableRow1);
     """.stripMargin
 
@@ -140,7 +140,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
       |long numBytesVariableRow2 = row2.getSizeInBytes() - $numBytesBitsetAndFixedRow2;
       |Platform.copyMemory(
       |  obj2, offset2 + ${(bitset2Words + schema2.size) * 8},
-      |  buf, $cursor + numBytesVariableRow1,
+      |  buf, $offset + $cursor + numBytesVariableRow1,
       |  numBytesVariableRow2);
     """.stripMargin
 
@@ -161,7 +161,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
      } else {
        s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)"
      }
-     val cursor = offset + outputBitsetWords * 8 + i * 8
+     val cursor = outputBitsetWords * 8 + i * 8
     // UnsafeRow is a little underspecified, so in what follows we'll treat UnsafeRowWriter's
     // output as a de-facto specification for the internal layout of data.
     //
@@ -198,9 +198,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
     // Thus it is safe to perform `existingOffset != 0` checks here in the place of
     // more expensive null-bit checks.
     s"""
-      |existingOffset = $getLong(buf, $cursor);
+      |existingOffset = $getLong(buf, $offset + $cursor);
       |if (existingOffset != 0) {
-      |  $putLong(buf, $cursor, existingOffset + ($shift << 32));
+      |  $putLong(buf, $offset + $cursor, existingOffset + ($shift << 32));
       |}
     """.stripMargin
    }
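
This file is the heart of the fix: `offset` changes from the driver's numeric Platform.BYTE_ARRAY_OFFSET to the literal string "Platform.BYTE_ARRAY_OFFSET", so each interpolation now emits the symbolic expression into the generated Java source instead of freezing the driver's value; since generated source is shipped as text and compiled on each executor, the expression resolves to that executor's own offset. The `$shift << 32` update is unaffected because a variable-length field's fixed-width slot holds (offset << 32) | size, so adding shift << 32 rebases the offset half and leaves the size untouched. A minimal sketch of the before/after interpolation (names such as numericOffset are illustrative only):

object OffsetInterpolation {
  def main(args: Array[String]): Unit = {
    val i = 1
    // Before: the driver's numeric offset (say 24) is baked into the source.
    val numericOffset: Long = 24L
    println(s"Platform.putLong(buf, ${numericOffset + i * 8}, bits);")
    // => Platform.putLong(buf, 32, bits);   wrong on an executor whose offset is 16
    // After: the symbolic name is emitted; each JVM resolves its own value.
    val offset = "Platform.BYTE_ARRAY_OFFSET"
    println(s"Platform.putLong(buf, $offset + ${i * 8}, bits);")
    // => Platform.putLong(buf, Platform.BYTE_ARRAY_OFFSET + 8, bits);
  }
}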
collectionOperations.scala
@@ -666,7 +666,7 @@ case class MapFromEntries(child: Expression) extends UnaryExpression {
     val keyArrayData = ctx.freshName("keyArrayData")
     val valueArrayData = ctx.freshName("valueArrayData")
 
-    val baseOffset = Platform.BYTE_ARRAY_OFFSET
+    val baseOffset = "Platform.BYTE_ARRAY_OFFSET"
     val keySize = dataType.keyType.defaultSize
     val valueSize = dataType.valueType.defaultSize
     val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $keySize)"
@@ -696,8 +696,8 @@ case class MapFromEntries(child: Expression) extends UnaryExpression {
       |  final byte[] $data = new byte[(int)$byteArraySize];
       |  UnsafeMapData $unsafeMapData = new UnsafeMapData();
       |  Platform.putLong($data, $baseOffset, $keySectionSize);
-      |  Platform.putLong($data, ${baseOffset + 8}, $numEntries);
-      |  Platform.putLong($data, ${baseOffset + 8} + $keySectionSize, $numEntries);
+      |  Platform.putLong($data, $baseOffset + 8, $numEntries);
+      |  Platform.putLong($data, $baseOffset + 8 + $keySectionSize, $numEntries);
       |  $unsafeMapData.pointTo($data, $baseOffset, (int)$byteArraySize);
       |  ArrayData $keyArrayData = $unsafeMapData.keyArray();
       |  ArrayData $valueArrayData = $unsafeMapData.valueArray();
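
One subtlety once baseOffset is a String: the old interpolation ${baseOffset + 8} would now concatenate rather than add, yielding the invalid token Platform.BYTE_ARRAY_OFFSET8 in the generated source, so the addition has to move into the emitted Java text as $baseOffset + 8. A small sketch of the pitfall (illustrative only):

object StringOffsetPitfall {
  def main(args: Array[String]): Unit = {
    val baseOffset = "Platform.BYTE_ARRAY_OFFSET"
    // Wrong: String + Int is concatenation, producing an invalid Java identifier.
    println(s"Platform.putLong(data, ${baseOffset + 8}, n);")
    // => Platform.putLong(data, Platform.BYTE_ARRAY_OFFSET8, n);
    // Right: emit the addition into the generated source itself.
    println(s"Platform.putLong(data, $baseOffset + 8, n);")
    // => Platform.putLong(data, Platform.BYTE_ARRAY_OFFSET + 8, n);
  }
}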
complexTypeCreator.scala
@@ -124,7 +124,7 @@ private [sql] object GenArrayData {
     val unsafeArraySizeInBytes =
       UnsafeArrayData.calculateHeaderPortionInBytes(numElements) +
       ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements)
-    val baseOffset = Platform.BYTE_ARRAY_OFFSET
+    val baseOffset = "Platform.BYTE_ARRAY_OFFSET"
 
     val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
     val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
WholeStageCodegenSparkSubmitSuite.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.scalatest.{Assertions, BeforeAndAfterEach, Matchers}
+import org.scalatest.concurrent.TimeLimits
+
+import org.apache.spark.{SparkFunSuite, TestUtils}
+import org.apache.spark.deploy.SparkSubmitSuite
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{LocalSparkSession, QueryTest, Row, SparkSession}
+import org.apache.spark.sql.functions.{array, col, count, lit}
+import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.util.ResetSystemProperties
+
+// Due to the need to set driver's extraJavaOptions, this test needs to use actual SparkSubmit.
+class WholeStageCodegenSparkSubmitSuite extends SparkFunSuite
+  with Matchers
+  with BeforeAndAfterEach
+  with ResetSystemProperties {
+
+  test("Generated code on driver should not embed platform-specific constant") {
+    val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+
+    // HotSpot JVM specific: Set up a local cluster with the driver/executor using mismatched
+    // settings of UseCompressedOops JVM option.
+    val argsForSparkSubmit = Seq(
+      "--class", WholeStageCodegenSparkSubmitSuite.getClass.getName.stripSuffix("$"),
+      "--master", "local-cluster[1,1,1024]",
+      "--driver-memory", "1g",
+      "--conf", "spark.ui.enabled=false",
+      "--conf", "spark.master.rest.enabled=false",
+      "--conf", "spark.driver.extraJavaOptions=-XX:-UseCompressedOops",
+      "--conf", "spark.executor.extraJavaOptions=-XX:+UseCompressedOops",
+      unusedJar.toString)
+    SparkSubmitSuite.runSparkSubmit(argsForSparkSubmit, "../..")
+  }
+}
+
+object WholeStageCodegenSparkSubmitSuite extends Assertions with Logging {
+
+  var spark: SparkSession = _
+
+  def main(args: Array[String]): Unit = {
+    TestUtils.configTestLog4j("INFO")
+
+    spark = SparkSession.builder().getOrCreate()
+
+    // Make sure the test is run where the driver and the executors use different object layouts
+    val driverArrayHeaderSize = Platform.BYTE_ARRAY_OFFSET
+    val executorArrayHeaderSize =
+      spark.sparkContext.range(0, 1).map(_ => Platform.BYTE_ARRAY_OFFSET).collect.head.toInt
+    assert(driverArrayHeaderSize > executorArrayHeaderSize)
+
+    val df = spark.range(71773).select((col("id") % lit(10)).cast(IntegerType) as "v")
+      .groupBy(array(col("v"))).agg(count(col("*")))
+    val plan = df.queryExecution.executedPlan
+    assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)
+
+    val expectedAnswer =
+      Row(Array(0), 7178) ::
+        Row(Array(1), 7178) ::
+        Row(Array(2), 7178) ::
+        Row(Array(3), 7177) ::
+        Row(Array(4), 7177) ::
+        Row(Array(5), 7177) ::
+        Row(Array(6), 7177) ::
+        Row(Array(7), 7177) ::
+        Row(Array(8), 7177) ::
+        Row(Array(9), 7177) :: Nil
+    val result = df.collect
+    QueryTest.sameRows(result.toSeq, expectedAnswer) match {
+      case Some(errMsg) => fail(errMsg)
+      case _ =>
+    }
+  }
+}
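
Sanity check on the expected answer: 71773 = 10 * 7177 + 3, so remainders 0 through 2 occur 7178 times and remainders 3 through 9 occur 7177 times, which matches the rows above. A quick standalone verification (illustrative, not part of the suite):

object ExpectedCounts {
  def main(args: Array[String]): Unit = {
    // Group 0..71772 by id % 10 and count, mirroring the DataFrame aggregation.
    val counts = (0L until 71773L).groupBy(_ % 10).map { case (r, xs) => r -> xs.size }
    counts.toSeq.sortBy(_._1).foreach { case (r, c) => println(s"$r -> $c") }
    // 0 -> 7178, 1 -> 7178, 2 -> 7178, 3 -> 7177, ..., 9 -> 7177
  }
}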