
Commit dbb4d83

xuanyuanking authored and HyukjinKwon committed
[SPARK-24215][PYSPARK] Implement _repr_html_ for dataframes in PySpark
## What changes were proposed in this pull request?

Implement `_repr_html_` for PySpark when running in a notebook, and add a config named "spark.sql.repl.eagerEval.enabled" to control it. The dev list thread for context: http://apache-spark-developers-list.1001551.n3.nabble.com/eager-execution-and-debuggability-td23928.html

## How was this patch tested?

A new unit test in DataFrameSuite and a manual test in Jupyter. Screenshots below.

**After:**

![image](https://user-images.githubusercontent.com/4833765/40268422-8db5bef0-5b9f-11e8-80f1-04bc654a4f2c.png)

**Before:**

![image](https://user-images.githubusercontent.com/4833765/40268431-9f92c1b8-5b9f-11e8-9db9-0611f0940b26.png)

Author: Yuanjian Li <[email protected]>

Closes #21370 from xuanyuanking/SPARK-24215.
1 parent ff0501b commit dbb4d83
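
For background, `_repr_html_` is the rich-display hook of the IPython/Jupyter display protocol: when an expression's result defines it and it returns non-None HTML, the notebook renders that HTML instead of falling back to `repr()`. A minimal standalone illustration of the protocol this patch plugs into (not Spark code):

```python
class Demo:
    def __repr__(self):
        # Plain REPLs only ever see this.
        return "Demo(plain)"

    def _repr_html_(self):
        # Jupyter prefers this when it is defined and returns non-None.
        return "<b>Demo(rich)</b>"
```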

File tree

4 files changed, +176 −30 lines


docs/configuration.md

Lines changed: 27 additions & 0 deletions
```diff
@@ -456,6 +456,33 @@ Apart from these, the following properties are also available, and may be useful
   from JVM to Python worker for every task.
   </td>
 </tr>
+<tr>
+  <td><code>spark.sql.repl.eagerEval.enabled</code></td>
+  <td>false</td>
+  <td>
+    Enables eager evaluation. If true and the REPL you are using supports eager evaluation, the
+    Dataset will be run automatically. The HTML table generated by <code>_repr_html_</code>,
+    which notebooks such as Jupyter call, shows the result of the query the user has defined. In the
+    plain Python REPL, the output is shown as in <code>dataframe.show()</code>
+    (see <a href="https://issues.apache.org/jira/browse/SPARK-24215">SPARK-24215</a> for more details).
+  </td>
+</tr>
+<tr>
+  <td><code>spark.sql.repl.eagerEval.maxNumRows</code></td>
+  <td>20</td>
+  <td>
+    The default number of rows in the eager evaluation output (the HTML table generated by
+    <code>_repr_html_</code>, or the plain text). This only takes effect when
+    <code>spark.sql.repl.eagerEval.enabled</code> is set to true.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.sql.repl.eagerEval.truncate</code></td>
+  <td>20</td>
+  <td>
+    The default number of characters to truncate cells to in the eager evaluation output (the HTML
+    table generated by <code>_repr_html_</code>, or the plain text). This only takes effect when
+    <code>spark.sql.repl.eagerEval.enabled</code> is set to true.
+  </td>
+</tr>
 <tr>
   <td><code>spark.files</code></td>
   <td></td>
```
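
For reference, a minimal sketch of turning these three settings on from PySpark; the session setup here is illustrative, not part of this patch:

```python
from pyspark.sql import SparkSession

# A local session purely for demonstration.
spark = SparkSession.builder.master("local[2]").appName("eager-eval-demo").getOrCreate()

# Enable eager evaluation and tune its output size.
spark.conf.set("spark.sql.repl.eagerEval.enabled", "true")
spark.conf.set("spark.sql.repl.eagerEval.maxNumRows", "10")  # show at most 10 rows
spark.conf.set("spark.sql.repl.eagerEval.truncate", "20")    # truncate cells to 20 chars

df = spark.range(5).toDF("id")
df  # in Jupyter this now renders an HTML table; in the plain REPL, a show()-style table
```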

python/pyspark/sql/dataframe.py

Lines changed: 64 additions & 1 deletion
```diff
@@ -78,6 +78,9 @@ def __init__(self, jdf, sql_ctx):
         self.is_cached = False
         self._schema = None  # initialized lazily
         self._lazy_rdd = None
+        # Track whether _repr_html_ is supported; we use it to avoid calling _jdf twice
+        # from __repr__ and _repr_html_ when eager evaluation is enabled.
+        self._support_repr_html = False

     @property
     @since(1.3)
@@ -351,8 +354,68 @@ def show(self, n=20, truncate=True, vertical=False):
         else:
             print(self._jdf.showString(n, int(truncate), vertical))

+    @property
+    def _eager_eval(self):
+        """Returns true if eager evaluation is enabled.
+        """
+        return self.sql_ctx.getConf(
+            "spark.sql.repl.eagerEval.enabled", "false").lower() == "true"
+
+    @property
+    def _max_num_rows(self):
+        """Returns the max number of rows for eager evaluation.
+        """
+        return int(self.sql_ctx.getConf(
+            "spark.sql.repl.eagerEval.maxNumRows", "20"))
+
+    @property
+    def _truncate(self):
+        """Returns the truncate length for eager evaluation.
+        """
+        return int(self.sql_ctx.getConf(
+            "spark.sql.repl.eagerEval.truncate", "20"))
+
     def __repr__(self):
-        return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
+        if not self._support_repr_html and self._eager_eval:
+            vertical = False
+            return self._jdf.showString(
+                self._max_num_rows, self._truncate, vertical)
+        else:
+            return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
+
+    def _repr_html_(self):
+        """Returns the DataFrame as an HTML table when eager evaluation is enabled
+        via 'spark.sql.repl.eagerEval.enabled'; this is only called by REPLs that
+        support eager evaluation with HTML.
+        """
+        import cgi
+        if not self._support_repr_html:
+            self._support_repr_html = True
+        if self._eager_eval:
+            max_num_rows = max(self._max_num_rows, 0)
+            vertical = False
+            sock_info = self._jdf.getRowsToPython(
+                max_num_rows, self._truncate, vertical)
+            rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))
+            head = rows[0]
+            row_data = rows[1:]
+            has_more_data = len(row_data) > max_num_rows
+            row_data = row_data[:max_num_rows]
+
+            html = "<table border='1'>\n"
+            # generate table head
+            html += "<tr><th>%s</th></tr>\n" % "</th><th>".join(map(lambda x: cgi.escape(x), head))
+            # generate table rows
+            for row in row_data:
+                html += "<tr><td>%s</td></tr>\n" % "</td><td>".join(
+                    map(lambda x: cgi.escape(x), row))
+            html += "</table>\n"
+            if has_more_data:
+                html += "only showing top %d %s\n" % (
+                    max_num_rows, "row" if max_num_rows == 1 else "rows")
+            return html
+        else:
+            return None

     @since(2.1)
     def checkpoint(self, eager=True):
```
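
The `_support_repr_html` flag is the subtle part of this diff: a plain REPL only ever calls `__repr__`, while Jupyter probes `_repr_html_` first, so the first `_repr_html_` call permanently flips the flag and `__repr__` falls back to the cheap schema string. A sketch of that interaction, assuming an eager-eval-enabled session named `spark`:

```python
df = spark.createDataFrame([(1, "a")], ["id", "letter"])

# Plain REPL path: __repr__ runs the query and returns a show()-style table.
print(repr(df))

# Notebook path: the frontend calls _repr_html_ first, which flips the flag...
html = df._repr_html_()  # returns "<table border='1'>..." (or None if eager eval is off)

# ...so a subsequent __repr__ no longer triggers a second job for the same cell.
print(repr(df))          # now just "DataFrame[id: bigint, letter: string]"
```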

python/pyspark/sql/tests.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -3074,6 +3074,36 @@ def test_checking_csv_header(self):
         finally:
             shutil.rmtree(path)

+    def test_repr_html(self):
+        import re
+        pattern = re.compile(r'^ *\|', re.MULTILINE)
+        df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value"))
+        self.assertEquals(None, df._repr_html_())
+        with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
+            expected1 = """<table border='1'>
+                |<tr><th>key</th><th>value</th></tr>
+                |<tr><td>1</td><td>1</td></tr>
+                |<tr><td>22222</td><td>22222</td></tr>
+                |</table>
+                |"""
+            self.assertEquals(re.sub(pattern, '', expected1), df._repr_html_())
+            with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
+                expected2 = """<table border='1'>
+                    |<tr><th>key</th><th>value</th></tr>
+                    |<tr><td>1</td><td>1</td></tr>
+                    |<tr><td>222</td><td>222</td></tr>
+                    |</table>
+                    |"""
+                self.assertEquals(re.sub(pattern, '', expected2), df._repr_html_())
+                with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
+                    expected3 = """<table border='1'>
+                        |<tr><th>key</th><th>value</th></tr>
+                        |<tr><td>1</td><td>1</td></tr>
+                        |</table>
+                        |only showing top 1 row
+                        |"""
+                    self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_())
+

 class HiveSparkSubmitTests(SparkSubmitTests):
```
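
The `pattern` regex in this test exists only to strip the leading whitespace-and-pipe margin from the triple-quoted expected strings, so they can be indented to match the surrounding code. A standalone illustration of that trick (not part of the patch):

```python
import re

pattern = re.compile(r'^ *\|', re.MULTILINE)
expected = """<table border='1'>
    |<tr><th>key</th></tr>
    |</table>
    |"""
# re.sub removes the "    |" margin at each line start, leaving the exact HTML to compare.
print(re.sub(pattern, '', expected))
```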

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 55 additions & 29 deletions
```diff
@@ -231,16 +231,17 @@ class Dataset[T] private[sql](
   }

   /**
-   * Compose the string representing rows for output
+   * Get rows represented as a Seq, subject to the given truncate and vertical requirements.
    *
-   * @param _numRows Number of rows to show
+   * @param numRows Number of rows to return
    * @param truncate If set to more than 0, truncates strings to `truncate` characters and
    *                 all cells will be aligned right.
-   * @param vertical If set to true, prints output rows vertically (one line per column value).
+   * @param vertical If set to true, the returned rows do not need to be truncated.
    */
-  private[sql] def showString(
-      _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = {
-    val numRows = _numRows.max(0).min(Int.MaxValue - 1)
+  private[sql] def getRows(
+      numRows: Int,
+      truncate: Int,
+      vertical: Boolean): Seq[Seq[String]] = {
     val newDf = toDF()
     val castCols = newDf.logicalPlan.output.map { col =>
       // Since binary types in top-level schema fields have a specific format to print,
@@ -251,14 +252,12 @@ class Dataset[T] private[sql](
         Column(col).cast(StringType)
       }
     }
-    val takeResult = newDf.select(castCols: _*).take(numRows + 1)
-    val hasMoreData = takeResult.length > numRows
-    val data = takeResult.take(numRows)
+    val data = newDf.select(castCols: _*).take(numRows + 1)

     // For array values, replace Seq and Array with square brackets
     // For cells that are beyond `truncate` characters, replace it with the
     // first `truncate-3` and "..."
-    val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row =>
+    schema.fieldNames.toSeq +: data.map { row =>
       row.toSeq.map { cell =>
         val str = cell match {
           case null => "null"
@@ -274,6 +273,26 @@ class Dataset[T] private[sql](
        }
      }: Seq[String]
    }
+  }
+
+  /**
+   * Compose the string representing rows for output
+   *
+   * @param _numRows Number of rows to show
+   * @param truncate If set to more than 0, truncates strings to `truncate` characters and
+   *                 all cells will be aligned right.
+   * @param vertical If set to true, prints output rows vertically (one line per column value).
+   */
+  private[sql] def showString(
+      _numRows: Int,
+      truncate: Int = 20,
+      vertical: Boolean = false): String = {
+    val numRows = _numRows.max(0).min(Int.MaxValue - 1)
+    // Get rows represented by Seq[Seq[String]]; we may get one more row than numRows if
+    // there is more data.
+    val tmpRows = getRows(numRows, truncate, vertical)
+
+    val hasMoreData = tmpRows.length - 1 > numRows
+    val rows = tmpRows.take(numRows + 1)

     val sb = new StringBuilder
     val numCols = schema.fieldNames.length
@@ -291,31 +310,25 @@ class Dataset[T] private[sql](
       }
     }

+    val paddedRows = rows.map { row =>
+      row.zipWithIndex.map { case (cell, i) =>
+        if (truncate > 0) {
+          StringUtils.leftPad(cell, colWidths(i))
+        } else {
+          StringUtils.rightPad(cell, colWidths(i))
+        }
+      }
+    }
+
     // Create SeparateLine
     val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString()

     // column names
-    rows.head.zipWithIndex.map { case (cell, i) =>
-      if (truncate > 0) {
-        StringUtils.leftPad(cell, colWidths(i))
-      } else {
-        StringUtils.rightPad(cell, colWidths(i))
-      }
-    }.addString(sb, "|", "|", "|\n")
-
+    paddedRows.head.addString(sb, "|", "|", "|\n")
     sb.append(sep)

     // data
-    rows.tail.foreach {
-      _.zipWithIndex.map { case (cell, i) =>
-        if (truncate > 0) {
-          StringUtils.leftPad(cell.toString, colWidths(i))
-        } else {
-          StringUtils.rightPad(cell.toString, colWidths(i))
-        }
-      }.addString(sb, "|", "|", "|\n")
-    }
-
+    paddedRows.tail.foreach(_.addString(sb, "|", "|", "|\n"))
     sb.append(sep)
   } else {
     // Extended display mode enabled
@@ -346,7 +359,7 @@ class Dataset[T] private[sql](
     }

     // Print a footer
-    if (vertical && data.isEmpty) {
+    if (vertical && rows.tail.isEmpty) {
       // In a vertical mode, print an empty row set explicitly
       sb.append("(0 rows)\n")
     } else if (hasMoreData) {
@@ -3209,6 +3222,19 @@ class Dataset[T] private[sql](
     }
   }

+  private[sql] def getRowsToPython(
+      _numRows: Int,
+      truncate: Int,
+      vertical: Boolean): Array[Any] = {
+    EvaluatePython.registerPicklers()
+    val numRows = _numRows.max(0).min(Int.MaxValue - 1)
+    val rows = getRows(numRows, truncate, vertical).map(_.toArray).toArray
+    val toJava: (Any) => Any = EvaluatePython.toJava(_, ArrayType(ArrayType(StringType)))
+    val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
+      rows.iterator.map(toJava))
+    PythonRDD.serveIterator(iter, "serve-GetRows")
+  }
+
   /**
    * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
    */
```
