From 33725ce4f8632fbb12e9707e71d5cb897e3b6038 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 18 Sep 2018 09:22:57 -0500 Subject: [PATCH] [SPARK-25456][SQL][TEST] Fix PythonForeachWriterSuite --- .../python/PythonForeachWriterSuite.scala | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala index 07e603477012..d02014c0dee5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala @@ -19,17 +19,20 @@ package org.apache.spark.sql.execution.python import scala.collection.mutable.ArrayBuffer +import org.mockito.Mockito.when import org.scalatest.concurrent.Eventually +import org.scalatest.mockito.MockitoSugar import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} +import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.execution.python.PythonForeachWriter.UnsafeRowBuffer import org.apache.spark.sql.types.{DataType, IntegerType} import org.apache.spark.util.Utils -class PythonForeachWriterSuite extends SparkFunSuite with Eventually { +class PythonForeachWriterSuite extends SparkFunSuite with Eventually with MockitoSugar { testWithBuffer("UnsafeRowBuffer: iterator blocks when no data is available") { b => b.assertIteratorBlocked() @@ -75,7 +78,7 @@ class PythonForeachWriterSuite extends SparkFunSuite with Eventually { tester = new BufferTester(memBytes, sleepPerRowReadMs) f(tester) } finally { - if (tester == null) tester.close() + if (tester != null) tester.close() } } } @@ -83,7 +86,12 @@ class PythonForeachWriterSuite extends SparkFunSuite with Eventually { class BufferTester(memBytes: Long, sleepPerRowReadMs: Int) { private val buffer = { - val mem = new TestMemoryManager(new SparkConf()) + val mockEnv = mock[SparkEnv] + val conf = new SparkConf() + val serializerManager = new SerializerManager(new JavaSerializer(conf), conf, None) + when(mockEnv.serializerManager).thenReturn(serializerManager) + SparkEnv.set(mockEnv) + val mem = new TestMemoryManager(conf) mem.limit(memBytes) val taskM = new TaskMemoryManager(mem, 0) new UnsafeRowBuffer(taskM, Utils.createTempDir(), 1)