Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,7 +1099,7 @@ def trunc(date, format):
"""
Returns date truncated to the unit specified by the format.

:param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm'
:param format: 'year', 'yyyy', 'yy' or 'month', 'mon', 'mm'

>>> df = spark.createDataFrame([('1997-02-28',)], ['d'])
>>> df.select(trunc(df.d, 'year').alias('year')).collect()
Expand All @@ -1111,6 +1111,24 @@ def trunc(date, format):
return Column(sc._jvm.functions.trunc(_to_java_column(date), format))


@since(2.3)
def date_trunc(format, timestamp):
    """
    Truncates a timestamp to the unit named by ``format``.

    :param format: 'year', 'yyyy', 'yy', 'month', 'mon', 'mm',
        'day', 'dd', 'hour', 'minute', 'second', 'week', 'quarter'
    :param timestamp: the timestamp column to truncate; it is converted with
        ``_to_java_column`` before being handed to the JVM function

    >>> df = spark.createDataFrame([('1997-02-28 05:02:11',)], ['t'])
    >>> df.select(date_trunc('year', df.t).alias('year')).collect()
    [Row(year=datetime.datetime(1997, 1, 1, 0, 0))]
    >>> df.select(date_trunc('mon', df.t).alias('month')).collect()
    [Row(month=datetime.datetime(1997, 2, 1, 0, 0))]
    """
    # Delegate to the JVM-side `functions.date_trunc`; note the (format, column) argument order.
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    truncated = jvm_functions.date_trunc(format, _to_java_column(timestamp))
    return Column(truncated)


@since(1.5)
def next_day(date, dayOfWeek):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ object FunctionRegistry {
expression[ToUnixTimestamp]("to_unix_timestamp"),
expression[ToUTCTimestamp]("to_utc_timestamp"),
expression[TruncDate]("trunc"),
expression[TruncTimestamp]("date_trunc"),
expression[UnixTimestamp]("unix_timestamp"),
expression[DayOfWeek]("dayofweek"),
expression[WeekOfYear]("weekofyear"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1295,87 +1295,181 @@ case class ParseToTimestamp(left: Expression, format: Option[Expression], child:
override def dataType: DataType = TimestampType
}

/**
* Returns date truncated to the unit specified by the format.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(date, fmt) - Returns `date` with the time portion of the day truncated to the unit specified by the format model `fmt`.",
examples = """
Examples:
> SELECT _FUNC_('2009-02-12', 'MM');
2009-02-01
> SELECT _FUNC_('2015-10-27', 'YEAR');
2015-01-01
""",
since = "1.5.0")
// scalastyle:on line.size.limit
case class TruncDate(date: Expression, format: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
override def left: Expression = date
override def right: Expression = format

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
override def dataType: DataType = DateType
/**
 * Shared evaluation and code generation for expressions that truncate an `instant`
 * (the date or timestamp operand) to a unit named by the string expression `format`.
 *
 * Implementations provide `instant` and `format`, and delegate `eval` / `doGenCode` to
 * [[evalHelper]] / [[codeGenHelper]] with the coarsest-to-finest bound they support.
 */
trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes {
  val instant: Expression
  val format: Expression
  override def nullable: Boolean = true

  // Parsed once and reused; only consulted when `format.foldable` is true, so evaluating
  // the format without an input row is safe here.
  private lazy val truncLevel: Int =
    DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])

  /**
   * Interpreted evaluation shared by the implementations.
   *
   * @param input internalRow (time)
   * @param maxLevel Maximum level that can be used for truncation (e.g MONTH for Date input)
   * @param truncFunc function: (time, level) => time
   * @return null when the format is unknown, when its level exceeds `maxLevel`, or when
   *         `instant` evaluates to null; otherwise the result of `truncFunc`
   */
  protected def evalHelper(input: InternalRow, maxLevel: Int)(
    truncFunc: (Any, Int) => Any): Any = {
    val level = if (format.foldable) {
      truncLevel
    } else {
      // A non-foldable format depends on the current row, so it must be evaluated
      // against `input` rather than with the default (null) row.
      DateTimeUtils.parseTruncLevel(format.eval(input).asInstanceOf[UTF8String])
    }
    if (level == DateTimeUtils.TRUNC_INVALID || level > maxLevel) {
      // unknown format or too large level
      null
    } else {
      val t = instant.eval(input)
      if (t == null) {
        null
      } else {
        truncFunc(t, level)
      }
    }
  }

  /**
   * Code generation shared by the implementations.
   *
   * @param ctx codegen context
   * @param ev the expression code being produced
   * @param maxLevel Maximum level that can be used for truncation
   * @param orderReversed false when the children are (instant, format); true when they are
   *                      (format, instant)
   * @param truncFunc function: (instant variable, level) => Java snippet that is spliced in
   *                  after the `DateTimeUtils` class name. In the non-foldable branch it is
   *                  emitted without a trailing `;`, so the snippet must supply its own.
   */
  protected def codeGenHelper(
      ctx: CodegenContext,
      ev: ExprCode,
      maxLevel: Int,
      orderReversed: Boolean = false)(
      truncFunc: (String, String) => String)
    : ExprCode = {
    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")

    if (format.foldable) {
      if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) {
        // The result is statically known to be null: emit only the declarations.
        ev.copy(code = s"""
          boolean ${ev.isNull} = true;
          ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""")
      } else {
        val t = instant.genCode(ctx)
        val truncFuncStr = truncFunc(t.value, truncLevel.toString)
        ev.copy(code = s"""
          ${t.code}
          boolean ${ev.isNull} = ${t.isNull};
          ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
          if (!${ev.isNull}) {
            ${ev.value} = $dtu.$truncFuncStr;
          }""")
      }
    } else {
      nullSafeCodeGen(ctx, ev, (leftVal, rightVal) => {
        val form = ctx.freshName("form")
        val (dateVal, fmt) = if (orderReversed) {
          (rightVal, leftVal)
        } else {
          (leftVal, rightVal)
        }
        val truncFuncStr = truncFunc(dateVal, form)
        // Compare against the same sentinel the interpreted path uses instead of a literal.
        s"""
          int $form = $dtu.parseTruncLevel($fmt);
          if ($form == ${DateTimeUtils.TRUNC_INVALID} || $form > $maxLevel) {
            ${ev.isNull} = true;
          } else {
            ${ev.value} = $dtu.$truncFuncStr
          }
        """
      })
    }
  }
}

/**
 * Truncates a date to the unit named by a format string.
 *
 * Both the interpreted and the generated path go through the [[TruncInstant]] helpers with
 * `DateTimeUtils.TRUNC_TO_MONTH` as the maximum level and `DateTimeUtils.truncDate` as the
 * truncation routine.
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = """
    _FUNC_(date, fmt) - Returns `date` with the time portion of the day truncated to the unit specified by the format model `fmt`.
    `fmt` should be one of ["year", "yyyy", "yy", "mon", "month", "mm"]
  """,
  examples = """
    Examples:
      > SELECT _FUNC_('2009-02-12', 'MM');
       2009-02-01
      > SELECT _FUNC_('2015-10-27', 'YEAR');
       2015-01-01
  """,
  since = "1.5.0")
// scalastyle:on line.size.limit
case class TruncDate(date: Expression, format: Expression)
  extends TruncInstant {

  // Children are ordered (date, format), i.e. the default order expected by codeGenHelper.
  override val instant: Expression = date
  override def left: Expression = date
  override def right: Expression = format

  override def prettyName: String = "trunc"
  override def dataType: DataType = DateType
  override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)

  override def eval(input: InternalRow): Any = {
    // The evaluated date is cast to Int before being passed to truncDate.
    val truncate: (Any, Int) => Any =
      (days, level) => DateTimeUtils.truncDate(days.asInstanceOf[Int], level)
    evalHelper(input, DateTimeUtils.TRUNC_TO_MONTH)(truncate)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // The snippet is appended to the DateTimeUtils class name by the helper.
    val call: (String, String) => String =
      (dateVar, levelVar) => s"truncDate($dateVar, $levelVar);"
    codeGenHelper(ctx, ev, DateTimeUtils.TRUNC_TO_MONTH)(call)
  }
}

/**
 * Returns timestamp truncated to the unit specified by the format.
 *
 * Unlike [[TruncDate]], the children are ordered (format, timestamp), so code generation
 * is invoked with `orderReversed = true`. Truncation is delegated to
 * `DateTimeUtils.truncTimestamp` with this expression's `timeZone`, using
 * `DateTimeUtils.TRUNC_TO_SECOND` as the maximum level.
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = """
    _FUNC_(fmt, ts) - Returns timestamp `ts` truncated to the unit specified by the format model `fmt`.
    `fmt` should be one of ["YEAR", "YYYY", "YY", "MON", "MONTH", "MM", "DAY", "DD", "HOUR", "MINUTE", "SECOND", "WEEK", "QUARTER"]
  """,
  examples = """
    Examples:
      > SELECT _FUNC_('YEAR', '2015-03-05T09:32:05.359');
       2015-01-01T00:00:00
      > SELECT _FUNC_('MM', '2015-03-05T09:32:05.359');
       2015-03-01T00:00:00
      > SELECT _FUNC_('DD', '2015-03-05T09:32:05.359');
       2015-03-05T00:00:00
      > SELECT _FUNC_('HOUR', '2015-03-05T09:32:05.359');
       2015-03-05T09:00:00
  """,
  since = "2.3.0")
// scalastyle:on line.size.limit
case class TruncTimestamp(
    format: Expression,
    timestamp: Expression,
    timeZoneId: Option[String] = None)
  extends TruncInstant with TimeZoneAwareExpression {
  override def left: Expression = format
  override def right: Expression = timestamp

  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, TimestampType)
  override def dataType: TimestampType = TimestampType
  override def prettyName: String = "date_trunc"
  override val instant = timestamp
  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
    copy(timeZoneId = Option(timeZoneId))

  // Two-argument constructor, leaving the time zone unset.
  def this(format: Expression, timestamp: Expression) = this(format, timestamp, None)

  override def eval(input: InternalRow): Any = {
    evalHelper(input, maxLevel = DateTimeUtils.TRUNC_TO_SECOND) { (t: Any, level: Int) =>
      DateTimeUtils.truncTimestamp(t.asInstanceOf[Long], level, timeZone)
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Register the time zone object so the generated code can reference it by name.
    val tz = ctx.addReferenceObj("timeZone", timeZone)
    // Children are (format, timestamp), the reverse of the helper's default order.
    codeGenHelper(ctx, ev, maxLevel = DateTimeUtils.TRUNC_TO_SECOND, orderReversed = true) {
      (date: String, fmt: String) =>
        s"truncTimestamp($date, $fmt, $tz);"
    }
  }
}

/**
* Returns the number of days from startDate to endDate.
*/
Expand Down
Loading