@@ -19,7 +19,6 @@
 import shutil
 import tempfile
 import time
-from random import randint
 
 from pyspark.sql import Row
 from pyspark.sql.functions import lit
@@ -572,28 +571,27 @@ def collectBatch(df, id):
                 q.stop()
 
     def test_streaming_read_from_table(self):
-        input_table_name = "sample_input_table_%d" % randint(0, 100000000)
-        self.spark.sql("CREATE TABLE %s (value string) USING parquet" % input_table_name)
-        self.spark.sql("INSERT INTO %s VALUES ('aaa'), ('bbb'), ('ccc')" % input_table_name)
-        df = self.spark.readStream.table(input_table_name)
-        self.assertTrue(df.isStreaming)
-        q = df.writeStream.format('memory').queryName('this_query').start()
-        q.processAllAvailable()
-        q.stop()
-        result = self.spark.sql("SELECT * FROM this_query ORDER BY value").collect()
-        self.assertEqual([Row(value='aaa'), Row(value='bbb'), Row(value='ccc')], result)
+        with self.table("input_table", "this_query"):
+            self.spark.sql("CREATE TABLE input_table (value string) USING parquet")
+            self.spark.sql("INSERT INTO input_table VALUES ('aaa'), ('bbb'), ('ccc')")
+            df = self.spark.readStream.table("input_table")
+            self.assertTrue(df.isStreaming)
+            q = df.writeStream.format('memory').queryName('this_query').start()
+            q.processAllAvailable()
+            q.stop()
+            result = self.spark.sql("SELECT * FROM this_query ORDER BY value").collect()
+            self.assertEqual(
+                set([Row(value='aaa'), Row(value='bbb'), Row(value='ccc')]), set(result))
 
     def test_streaming_write_to_table(self):
-        output_table_name = "sample_output_table_%d" % randint(0, 100000000)
-        tmpPath = tempfile.mkdtemp()
-        shutil.rmtree(tmpPath)
-        df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
-        q = df.writeStream.toTable(output_table_name, format='parquet', checkpointLocation=tmpPath)
-        self.assertTrue(q.isActive)
-        time.sleep(3)
-        q.stop()
-        result = self.spark.sql("SELECT value FROM %s" % output_table_name).collect()
-        self.assertTrue(len(result) > 0)
+        with self.table("output_table"), tempfile.TemporaryDirectory() as tmpdir:
+            df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
+            q = df.writeStream.toTable("output_table", format='parquet', checkpointLocation=tmpdir)
+            self.assertTrue(q.isActive)
+            time.sleep(3)
+            q.stop()
+            result = self.spark.sql("SELECT value FROM output_table").collect()
+            self.assertTrue(len(result) > 0)
 
 
 if __name__ == "__main__":
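
The self.table(...) context manager used in both tests comes from PySpark's shared SQL test utilities, not from this diff. A minimal sketch of what such a helper looks like, assuming it behaves like the SQLTestUtils.table helper (drop-on-exit cleanup; the exact implementation may differ):

    from contextlib import contextmanager

    @contextmanager
    def table(self, *tables):
        # Run the enclosed test body, then unconditionally drop every named
        # table so repeated runs start from a clean catalog; this is what
        # makes the old randint()-suffixed table names unnecessary.
        try:
            yield
        finally:
            for t in tables:
                self.spark.sql("DROP TABLE IF EXISTS %s" % t)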
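
On the write side, tempfile.TemporaryDirectory() replaces the old mkdtemp()-then-rmtree() pattern and, unlike it, also removes the checkpoint directory once the test finishes. A standalone sketch of the same pattern outside the test harness, assuming an active SparkSession bound to spark (the table name rate_sink is illustrative):

    import tempfile
    import time

    with tempfile.TemporaryDirectory() as checkpoint_dir:
        # The rate source generates rowsPerSecond rows per second with
        # (timestamp, value) columns; toTable() streams them into a table.
        df = spark.readStream.format("rate").option("rowsPerSecond", 10).load()
        q = df.writeStream.toTable(
            "rate_sink", format="parquet", checkpointLocation=checkpoint_dir)
        time.sleep(3)  # let a few micro-batches commit before stopping
        q.stop()
    spark.sql("SELECT count(*) FROM rate_sink").show()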