Commit 02ce522

Use test utils and clean up the examples in doctests in table and toTable
1 parent 4b19f49 commit 02ce522
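
The "test utils" referred to in the commit message is a table context manager on Spark's SQL test mixin: it drops the named tables on exit, so the tests no longer need randint-suffixed table names to dodge collisions between runs. A minimal sketch of such a helper, assuming a self.spark session on the test class (the real implementation lives in Spark's SQLTestUtils and may differ in detail):

    from contextlib import contextmanager

    @contextmanager
    def table(self, *tables):
        """Run the enclosed block, then drop every named table."""
        try:
            yield
        finally:
            for t in tables:
                # DROP TABLE IF EXISTS keeps cleanup safe even when the
                # test failed before the table was ever created.
                self.spark.sql("DROP TABLE IF EXISTS %s" % t)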

File tree

2 files changed: +28 -40 lines changed

python/pyspark/sql/streaming.py

Lines changed: 9 additions & 19 deletions
@@ -974,9 +974,7 @@ def table(self, tableName):
 
         Examples
         --------
-        >>> csv_sdf = spark.readStream.table('input_table') # doctest: +SKIP
-        >>> csv_sdf.isStreaming # doctest: +SKIP
-        True
+        >>> spark.readStream.table('input_table') # doctest: +SKIP
         """
         if isinstance(tableName, str):
             return self._df(self._jreader.table(tableName))
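
The simplified doctest keeps the +SKIP directive because it presumes a live session and a pre-existing input_table. A runnable equivalent, with hypothetical setup added so the read has something to stream from, would look roughly like:

    # Hypothetical setup so that readStream.table has a source to read.
    spark.sql("CREATE TABLE input_table (value string) USING parquet")
    sdf = spark.readStream.table("input_table")
    assert sdf.isStreaming  # the property the removed doctest lines asserted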
@@ -1535,23 +1533,15 @@ def toTable(self, tableName, format=None, outputMode=None, partitionBy=None, que
 
         Examples
         --------
-        >>> sq = sdf.writeStream.format('parquet').queryName('this_query').option(
-        ...     'checkpointLocation', '/tmp/checkpoint').toTable('output_table') # doctest: +SKIP
-        >>> sq.isActive # doctest: +SKIP
-        True
-        >>> sq.name # doctest: +SKIP
-        'this_query'
-        >>> sq.stop() # doctest: +SKIP
-        >>> sq.isActive # doctest: +SKIP
-        False
-        >>> sq = sdf.writeStream.trigger(processingTime='5 seconds').toTable(
-        ...     'output_table', queryName='that_query', outputMode="append", format='parquet',
+        >>> sdf.writeStream.format('parquet').queryName('query').toTable('output_table')
+        ... # doctest: +SKIP
+
+        >>> sdf.writeStream.trigger(processingTime='5 seconds').toTable(
+        ...     'output_table',
+        ...     queryName='that_query',
+        ...     outputMode="append",
+        ...     format='parquet',
         ...     checkpointLocation='/tmp/checkpoint') # doctest: +SKIP
-        >>> sq.name # doctest: +SKIP
-        'that_query'
-        >>> sq.isActive # doctest: +SKIP
-        True
-        >>> sq.stop() # doctest: +SKIP
         """
         # TODO(SPARK-33659): document the current behavior for DataStreamWriter.toTable API
         self.options(**options)
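
Note that the trimmed doctest no longer walks through the StreamingQuery lifecycle. For reference, the round trip the removed lines exercised looks roughly like this (assuming sdf is a streaming DataFrame and /tmp/checkpoint is writable):

    sq = sdf.writeStream.format('parquet').queryName('this_query').option(
        'checkpointLocation', '/tmp/checkpoint').toTable('output_table')
    assert sq.isActive               # toTable() starts the query immediately
    assert sq.name == 'this_query'
    sq.stop()                        # stop the continuous execution
    assert not sq.isActive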

python/pyspark/sql/tests/test_streaming.py

Lines changed: 19 additions & 21 deletions
@@ -19,7 +19,6 @@
1919
import shutil
2020
import tempfile
2121
import time
22-
from random import randint
2322

2423
from pyspark.sql import Row
2524
from pyspark.sql.functions import lit
@@ -572,28 +571,27 @@ def collectBatch(df, id):
         q.stop()
 
     def test_streaming_read_from_table(self):
-        input_table_name = "sample_input_table_%d" % randint(0, 100000000)
-        self.spark.sql("CREATE TABLE %s (value string) USING parquet" % input_table_name)
-        self.spark.sql("INSERT INTO %s VALUES ('aaa'), ('bbb'), ('ccc')" % input_table_name)
-        df = self.spark.readStream.table(input_table_name)
-        self.assertTrue(df.isStreaming)
-        q = df.writeStream.format('memory').queryName('this_query').start()
-        q.processAllAvailable()
-        q.stop()
-        result = self.spark.sql("SELECT * FROM this_query ORDER BY value").collect()
-        self.assertEqual([Row(value='aaa'), Row(value='bbb'), Row(value='ccc')], result)
+        with self.table("input_table", "this_query"):
+            self.spark.sql("CREATE TABLE input_table (value string) USING parquet")
+            self.spark.sql("INSERT INTO input_table VALUES ('aaa'), ('bbb'), ('ccc')")
+            df = self.spark.readStream.table("input_table")
+            self.assertTrue(df.isStreaming)
+            q = df.writeStream.format('memory').queryName('this_query').start()
+            q.processAllAvailable()
+            q.stop()
+            result = self.spark.sql("SELECT * FROM this_query ORDER BY value").collect()
+            self.assertEqual(
+                set([Row(value='aaa'), Row(value='bbb'), Row(value='ccc')]), set(result))
 
     def test_streaming_write_to_table(self):
-        output_table_name = "sample_output_table_%d" % randint(0, 100000000)
-        tmpPath = tempfile.mkdtemp()
-        shutil.rmtree(tmpPath)
-        df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
-        q = df.writeStream.toTable(output_table_name, format='parquet', checkpointLocation=tmpPath)
-        self.assertTrue(q.isActive)
-        time.sleep(3)
-        q.stop()
-        result = self.spark.sql("SELECT value FROM %s" % output_table_name).collect()
-        self.assertTrue(len(result) > 0)
+        with self.table("output_table"), tempfile.TemporaryDirectory() as tmpdir:
+            df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
+            q = df.writeStream.toTable("output_table", format='parquet', checkpointLocation=tmpdir)
+            self.assertTrue(q.isActive)
+            time.sleep(3)
+            q.stop()
+            result = self.spark.sql("SELECT value FROM output_table").collect()
+            self.assertTrue(len(result) > 0)
 
 
 if __name__ == "__main__":
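
Beyond deterministic table names, the write test also swaps tempfile.mkdtemp() plus shutil.rmtree() for tempfile.TemporaryDirectory(). A sketch of the difference (in the old pattern the directory was deleted up front so the checkpoint path started out empty, and nothing ever cleaned up the checkpoint data the query wrote there):

    import shutil
    import tempfile

    # Old pattern: create, delete immediately, never clean up afterwards.
    tmpPath = tempfile.mkdtemp()
    shutil.rmtree(tmpPath)

    # New pattern: the directory exists for the whole block and is removed
    # automatically, even when an assertion fails inside it.
    with tempfile.TemporaryDirectory() as tmpdir:
        pass  # pass tmpdir as the checkpointLocation here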
