diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index db1cd1c013be..63128ef48e38 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -477,8 +477,20 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": def __repr__(self) -> str: if self._value is None: return "NULL" - else: - return f"{self._value}" + elif isinstance(self._dataType, DateType): + dt = DateType().fromInternal(self._value) + if dt is not None and isinstance(dt, datetime.date): + return dt.strftime("%Y-%m-%d") + elif isinstance(self._dataType, TimestampType): + ts = TimestampType().fromInternal(self._value) + if ts is not None and isinstance(ts, datetime.datetime): + return ts.strftime("%Y-%m-%d %H:%M:%S.%f") + elif isinstance(self._dataType, TimestampNTZType): + ts = TimestampNTZType().fromInternal(self._value) + if ts is not None and isinstance(ts, datetime.datetime): + return ts.strftime("%Y-%m-%d %H:%M:%S.%f") + # TODO(SPARK-49693): Refine the string representation of timedelta + return f"{self._value}" class ColumnReference(Expression): diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index 2bd66baaa2bf..220ecd387f7e 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -18,6 +18,8 @@ from enum import Enum from itertools import chain +import datetime + from pyspark.sql import Column, Row from pyspark.sql import functions as sf from pyspark.sql.types import StructType, StructField, IntegerType, LongType @@ -280,6 +282,13 @@ def test_expr_str_representation(self): when_cond = sf.when(expression, sf.lit(None)) self.assertEqual(str(when_cond), "Column<'CASE WHEN foo THEN NULL END'>") + def test_lit_time_representation(self): + dt = datetime.date(2021, 3, 4) + self.assertEqual(str(sf.lit(dt)), "Column<'2021-03-04'>") + + ts = datetime.datetime(2021, 3, 4, 12, 34, 56, 1234) + self.assertEqual(str(sf.lit(ts)), "Column<'2021-03-04 12:34:56.001234'>") + def test_enum_literals(self): class IntEnum(Enum): X = 1