Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,8 @@ def trunc(date, format):
"""
Returns date truncated to the unit specified by the format.

:param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm'
:param format: 'year', 'YYYY', 'yy', 'month', 'mon', 'mm', 'day', 'dd', 'hour', 'hh', 'mi',
or 'sec', 'ss'.

>>> df = spark.createDataFrame([('1997-02-28',)], ['d'])
>>> df.select(trunc(df.d, 'year').alias('year')).collect()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ object FunctionRegistry {
expression[CurrentDate]("current_date"),
expression[CurrentTimestamp]("current_timestamp"),
expression[DateDiff]("datediff"),
expression[DateAdd]("date_add"),
expression[AddDays]("date_add"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may understand why hive call it date_add...

Other databases also have date_add, but it takes interval type, which means it can handle days or months or both.

To keep hive-compatibility, we should keep it as it was, but rename the expression class name LGTM

expression[DateFormatClass]("date_format"),
expression[DateSub]("date_sub"),
expression[DayOfMonth]("day"),
Expand All @@ -342,7 +342,7 @@ object FunctionRegistry {
expression[ToDate]("to_date"),
expression[ToUnixTimestamp]("to_unix_timestamp"),
expression[ToUTCTimestamp]("to_utc_timestamp"),
expression[TruncDate]("trunc"),
expression[TruncInstant]("trunc"),
expression[UnixTimestamp]("unix_timestamp"),
expression[WeekOfYear]("weekofyear"),
expression[Year]("year"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,58 +72,64 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback {
}

/**
* Adds a number of days to startdate.
* The base for addition/subtraction for days.
Copy link
Member Author

@HyukjinKwon HyukjinKwon Sep 8, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems it shows the diff a bit messy here. I made a common parent for AddDays and SubDays named AddDaysBase.

*/
@ExpressionDescription(
usage = "_FUNC_(start_date, num_days) - Returns the date that is num_days after start_date.",
extended = "> SELECT _FUNC_('2016-07-30', 1);\n '2016-07-31'")
case class DateAdd(startDate: Expression, days: Expression)
abstract class AddDaysBase(instant: Expression, days: Expression)
extends BinaryExpression with ImplicitCastInputTypes {

override def left: Expression = startDate
override def left: Expression = instant
override def right: Expression = days

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(DateType, TimestampType), IntegerType)

override def dataType: DataType = DateType
override def dataType: DataType = instant.dataType

// 1 for addition, -1 for subtraction
def signModifier: Int

override def nullSafeEval(start: Any, d: Any): Any = {
start.asInstanceOf[Int] + d.asInstanceOf[Int]
override def nullSafeEval(start: Any, days: Any): Any = {
(instant.dataType, start, days) match {
case (_: DateType, startDate: Int, days: Int) =>
startDate + (signModifier * days)
case (_: TimestampType, startTimestamp: Long, days: Int) =>
DateTimeUtils.timestampAddDays(startTimestamp, signModifier * days)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (sd, d) => {
s"""${ev.value} = $sd + $d;"""
})
instant.dataType match {
case DateType =>
nullSafeCodeGen(ctx, ev, (sd, d) => s"""${ev.value} = $sd + ($signModifier * $d);""")
case TimestampType =>
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (sd, d) => s"""$dtu.timestampAddDays($sd, ($signModifier * $d))""")
}
}

override def prettyName: String = "date_add"
}

/**
* Subtracts a number of days to startdate.
* Adds a number of days to date/timestamp.
*/
@ExpressionDescription(
usage = "_FUNC_(start_date, num_days) - Returns the date that is num_days before start_date.",
extended = "> SELECT _FUNC_('2016-07-30', 1);\n '2016-07-29'")
case class DateSub(startDate: Expression, days: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
override def left: Expression = startDate
override def right: Expression = days
usage = "_FUNC_(instant, num_days) - Returns the date/timestamp that is num_days after instant.",
extended = "> SELECT _FUNC_('2016-07-30', 1);\n '2016-07-31'")
case class AddDays(instant: Expression, days: Expression) extends AddDaysBase(instant, days) {

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)
override def signModifier: Int = 1

override def dataType: DataType = DateType
override def prettyName: String = "date_add"
}

override def nullSafeEval(start: Any, d: Any): Any = {
start.asInstanceOf[Int] - d.asInstanceOf[Int]
}
/**
* Subtracts a number of days to date/timestamp.
*/
@ExpressionDescription(
usage = "_FUNC_(instant, num_days) - Returns the date/timestamp that is num_days before instant.",
extended = "> SELECT _FUNC_('2016-07-30', 1);\n '2016-07-29'")
case class DateSub(instant: Expression, days: Expression) extends AddDaysBase(instant, days) {

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (sd, d) => {
s"""${ev.value} = $sd - $d;"""
})
}
override def signModifier: Int = -1

override def prettyName: String = "date_sub"
}
Expand Down Expand Up @@ -781,29 +787,38 @@ case class TimeSub(start: Expression, interval: Expression)
}

/**
* Returns the date that is num_months after start_date.
* Returns the date/timestamp that is num_months after instant.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(start_date, num_months) - Returns the date that is num_months after start_date.",
usage = "_FUNC_(instant, num_months) - Returns the date/timestamp that is num_months after instant.",
extended = "> SELECT _FUNC_('2016-08-31', 1);\n '2016-09-30'")
case class AddMonths(startDate: Expression, numMonths: Expression)
// scalastyle:on line.size.limit
case class AddMonths(instant: Expression, numMonths: Expression)
extends BinaryExpression with ImplicitCastInputTypes {

override def left: Expression = startDate
override def left: Expression = instant
override def right: Expression = numMonths

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(DateType, TimestampType), IntegerType)

override def dataType: DataType = DateType
override def dataType: DataType = instant.dataType

override def nullSafeEval(start: Any, months: Any): Any = {
DateTimeUtils.dateAddMonths(start.asInstanceOf[Int], months.asInstanceOf[Int])
(instant.dataType, start, months) match {
case (_: DateType, startDate: Int, months: Int) =>
DateTimeUtils.dateAddMonths(startDate, months)
case (_: TimestampType, startTimestamp: Long, months: Int) =>
DateTimeUtils.timestampAddInterval(startTimestamp, months, 0)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (sd, m) => {
s"""$dtu.dateAddMonths($sd, $m)"""
defineCodeGen(ctx, ev, (sd, m) => instant.dataType match {
case DateType => s"""$dtu.dateAddMonths($sd, $m)"""
case TimestampType => s"""$dtu.timestampAddInterval($sd, $m, 0)"""
})
}

Expand Down Expand Up @@ -916,21 +931,26 @@ case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastIn
}

/**
* Returns date truncated to the unit specified by the format.
* Returns timestamp truncated to the unit specified by the format.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(date, fmt) - Returns returns date with the time portion of the day truncated to the unit specified by the format model fmt.",
extended = "> SELECT _FUNC_('2009-02-12', 'MM')\n '2009-02-01'\n> SELECT _FUNC_('2015-10-27', 'YEAR');\n '2015-01-01'")
usage = "_FUNC_(instant, fmt) - Returns returns date/timestamp with the time portion truncated to the unit specified by the format model fmt.",
extended = "> SELECT _FUNC_('2009-02-12', 'MM')\n '2009-02-01 00:00:00'\n> SELECT _FUNC_('2015-10-27', 'YEAR');\n '2015-01-01 00:00:00'")
// scalastyle:on line.size.limit
case class TruncDate(date: Expression, format: Expression)
case class TruncInstant(instant: Expression, format: Expression)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This auto casts. I think we are still breaking API here if a user passes a Timestamp. In the old situation the user would always get a Date, and now he gets a Date or Timestamp based on the input type. So I think we need to split this into two expressions.

Copy link
Member Author

@HyukjinKwon HyukjinKwon Oct 13, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I misunderstood the first comment.. Will make two expressions. Thanks!

Copy link
Member Author

@HyukjinKwon HyukjinKwon Oct 13, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hvanhovell ur.. actually, should I split other functions I corrected here as well here? DateAdd, DateSub and etc. also seem having the same problems.

extends BinaryExpression with ImplicitCastInputTypes {
override def left: Expression = date

override def left: Expression = instant
override def right: Expression = format

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
override def dataType: DataType = DateType
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(DateType, TimestampType), StringType)

override def dataType: DataType = instant.dataType

override def nullable: Boolean = true

override def prettyName: String = "trunc"

private lazy val truncLevel: Int =
Expand All @@ -946,11 +966,12 @@ case class TruncDate(date: Expression, format: Expression)
// unknown format
null
} else {
val d = date.eval(input)
if (d == null) {
null
} else {
DateTimeUtils.truncDate(d.asInstanceOf[Int], level)
(instant.dataType, instant.eval(input)) match {
case (_: DateType, date: Int) =>
DateTimeUtils.truncateInstant(date, level)
case (_: TimestampType, timestamp: Long) =>
DateTimeUtils.truncateInstant(timestamp, level)
case (_, null) => null
}
}
}
Expand All @@ -964,13 +985,13 @@ case class TruncDate(date: Expression, format: Expression)
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""")
} else {
val d = date.genCode(ctx)
val ist = instant.genCode(ctx)
ev.copy(code = s"""
${d.code}
boolean ${ev.isNull} = ${d.isNull};
${ist.code}
boolean ${ev.isNull} = ${ist.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $dtu.truncDate(${d.value}, $truncLevel);
${ev.value} = $dtu.truncateInstant(${ist.value}, $truncLevel);
}""")
}
} else {
Expand All @@ -981,7 +1002,7 @@ case class TruncDate(date: Expression, format: Expression)
if ($form == -1) {
${ev.isNull} = true;
} else {
${ev.value} = $dtu.truncDate($dateVal, $form);
${ev.value} = $dtu.truncateInstant($dateVal, $form);
}
"""
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,16 @@ object DateTimeUtils {
// see http://stackoverflow.com/questions/466321/convert-unix-timestamp-to-julian
// it's 2440587.5, rounding up to compatible with Hive
final val JULIAN_DAY_OF_EPOCH = 2440588

final val SECONDS_PER_DAY = 60 * 60 * 24L
final val MICROS_PER_SECOND = 1000L * 1000L
final val NANOS_PER_SECOND = MICROS_PER_SECOND * 1000L
final val MICROS_PER_DAY = MICROS_PER_SECOND * SECONDS_PER_DAY

final val MILLIS_PER_DAY = SECONDS_PER_DAY * 1000L
final val MILLIS_PER_SECOND = 1000L
final val MILLIS_PER_MINUTE = 60L * MILLIS_PER_SECOND
final val MILLIS_PER_HOUR = 60L * 60L * MILLIS_PER_SECOND
final val MILLIS_PER_DAY = SECONDS_PER_DAY * MILLIS_PER_SECOND

final val MICROS_PER_SECOND = MILLIS_PER_SECOND * 1000L
final val MICROS_PER_DAY = SECONDS_PER_DAY * MICROS_PER_SECOND

// number of days in 400 years
final val daysIn400Years: Int = 146097
Expand Down Expand Up @@ -747,6 +751,14 @@ object DateTimeUtils {
daysToMillis(newDays) * 1000L + start - daysToMillis(days) * 1000L + microseconds
}

/**
* Add timestamp and days interval.
* Returns a timestamp value, expressed in microseconds since 1.1.1970 00:00:00.
*/
def timestampAddDays(start: SQLTimestamp, days: Int): SQLTimestamp = {
start + days * MICROS_PER_DAY
}

/**
* Returns number of months between time1 and time2. time1 and time2 are expressed in
* microseconds since 1.1.1970.
Expand Down Expand Up @@ -817,13 +829,29 @@ object DateTimeUtils {

private val TRUNC_TO_YEAR = 1
private val TRUNC_TO_MONTH = 2
private val TRUNC_TO_DAY = 3
private val TRUNC_TO_HOUR = 4
private val TRUNC_TO_MINUTE = 5
private val TRUNC_TO_SECOND = 6
private val TRUNC_INVALID = -1

/**
* Returns the trunc timestamp from original timestamp and trunc level.
* Trunc level should be generated using `parseTruncLevel()`, should only be 1 - 6.
*/
def truncateInstant(ts: SQLTimestamp, level: Int): SQLTimestamp = {
if (level == TRUNC_TO_YEAR || level == TRUNC_TO_MONTH) {
daysToMillis(truncateInstant(millisToDays(ts / 1000L), level)) * 1000L
} else {
truncateTime(ts, level)
}
}

/**
* Returns the trunc date from original date and trunc level.
* Trunc level should be generated using `parseTruncLevel()`, should only be 1 or 2.
*/
def truncDate(d: SQLDate, level: Int): SQLDate = {
def truncateInstant(d: SQLDate, level: Int): SQLDate = {
if (level == TRUNC_TO_YEAR) {
d - DateTimeUtils.getDayInYear(d) + 1
} else if (level == TRUNC_TO_MONTH) {
Expand All @@ -834,8 +862,28 @@ object DateTimeUtils {
}
}

private def truncateTime(ts: SQLTimestamp, level: Int): SQLTimestamp = {
val unitInMillis = level match {
case TRUNC_TO_DAY => MILLIS_PER_DAY
case TRUNC_TO_HOUR => MILLIS_PER_HOUR
case TRUNC_TO_MINUTE => MILLIS_PER_MINUTE
case TRUNC_TO_SECOND => MILLIS_PER_SECOND
case _ =>
// caller make sure that this should never be reached
sys.error(s"Invalid trunc level: $level")
}

val millisUtc = ts / 1000L
val millisLocal = millisUtc + threadLocalLocalTimeZone.get().getOffset(millisUtc)
val truncatedMillisLocal = millisLocal - millisLocal % unitInMillis
val offset = getOffsetFromLocalMillis(truncatedMillisLocal, threadLocalLocalTimeZone.get())
val truncatedMillis = truncatedMillisLocal - offset
truncatedMillis * 1000L
}

/**
* Returns the truncate level, could be TRUNC_YEAR, TRUNC_MONTH, or TRUNC_INVALID,
* Returns the truncate level, could be TRUNC_YEAR, TRUNC_MONTH, TRUNC_TO_DAY, TRUNC_TO_HOUR,
* TRUNC_TO_MINUTE, TRUNC_TO_SECOND or TRUNC_INVALID.
* TRUNC_INVALID means unsupported truncate level.
*/
def parseTruncLevel(format: UTF8String): Int = {
Expand All @@ -845,6 +893,10 @@ object DateTimeUtils {
format.toString.toUpperCase match {
case "YEAR" | "YYYY" | "YY" => TRUNC_TO_YEAR
case "MON" | "MONTH" | "MM" => TRUNC_TO_MONTH
case "DAY" | "DD" => TRUNC_TO_DAY
case "HOUR" | "HH" => TRUNC_TO_HOUR
case "MI" => TRUNC_TO_MINUTE
case "SEC" | "SS" => TRUNC_TO_SECOND
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

case _ => TRUNC_INVALID
}
}
Expand Down
Loading