Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,36 @@ def months_between(date1, date2):
return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2)))


@since(1.5)
def to_date(col):
"""
Converts the column of StringType or TimestampType into DateType.

>>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
>>> df.select(to_date(df.t).alias('date')).collect()
[Row(date=datetime.date(1997, 2, 28))]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.to_date(_to_java_column(col)))


@since(1.5)
def trunc(date, format):
"""
Returns date truncated to the unit specified by the format.

:param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm'

>>> df = sqlContext.createDataFrame([('1997-02-28',)], ['d'])
>>> df.select(trunc(df.d, 'year').alias('year')).collect()
[Row(year=datetime.date(1997, 1, 1))]
>>> df.select(trunc(df.d, 'mon').alias('month')).collect()
[Row(month=datetime.date(1997, 2, 1))]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.trunc(_to_java_column(date), format))


@since(1.5)
def size(col):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ object FunctionRegistry {
expression[NextDay]("next_day"),
expression[Quarter]("quarter"),
expression[Second]("second"),
expression[ToDate]("to_date"),
expression[TruncDate]("trunc"),
expression[UnixTimestamp]("unix_timestamp"),
expression[WeekOfYear]("weekofyear"),
expression[Year]("year"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,6 @@ case class FromUnixTime(sec: Expression, format: Expression)
})
}
}

}

/**
Expand Down Expand Up @@ -696,3 +695,90 @@ case class MonthsBetween(date1: Expression, date2: Expression)
})
}
}

/**
* Returns the date part of a timestamp or string.
*/
case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

// Implicit casting of spark will accept string in both date and timestamp format, as
// well as TimestampType.
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)

override def dataType: DataType = DateType

override def eval(input: InternalRow): Any = child.eval(input)

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, d => d)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tiny optimization - this can just call child.genCode

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

genCode is protected.

}
}

/*
* Returns date truncated to the unit specified by the format.
*/
case class TruncDate(date: Expression, format: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
override def left: Expression = date
override def right: Expression = format

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
override def dataType: DataType = DateType
override def prettyName: String = "trunc"

lazy val minItemConst = DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mark this as private, and maybe rename it to truncationLevel?


override def eval(input: InternalRow): Any = {
val minItem = if (format.foldable) {
minItemConst
} else {
DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String])
}
if (minItem == -1) {
// unknown format
null
} else {
val d = date.eval(input)
if (d == null) {
null
} else {
DateTimeUtils.truncDate(d.asInstanceOf[Int], minItem)
}
}
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")

if (format.foldable) {
if (minItemConst == -1) {
s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
"""
} else {
val d = date.gen(ctx)
s"""
${d.code}
boolean ${ev.isNull} = ${d.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.primitive} = $dtu.truncDate(${d.primitive}, $minItemConst);
}
"""
}
} else {
nullSafeCodeGen(ctx, ev, (dateVal, fmt) => {
val form = ctx.freshName("form")
s"""
int $form = $dtu.parseTruncLevel($fmt);
if ($form == -1) {
${ev.isNull} = true;
} else {
${ev.primitive} = $dtu.truncDate($dateVal, $form);
}
"""
})
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -779,4 +779,38 @@ object DateTimeUtils {
}
date + (lastDayOfMonthInYear - dayInYear)
}

private val TRUNC_TO_YEAR = 1
private val TRUNC_TO_MONTH = 2
private val TRUNC_INVALID = -1

/**
* Returns the trunc date from original date and trunc level.
* Trunc level should be generated using `parseTruncLevel()`, should only be 1 or 2.
*/
def truncDate(d: Int, level: Int): Int = {
if (level == TRUNC_TO_YEAR) {
d - DateTimeUtils.getDayInYear(d) + 1
} else if (level == TRUNC_TO_MONTH) {
d - DateTimeUtils.getDayOfMonth(d) + 1
} else {
throw new Exception(s"Invalid trunc level: $level")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sys.error("...")

and add a comment explaining this should never be hit because trunc level is internally generated.

}
}

/**
* Returns the truncate level, could be TRUNC_YEAR, TRUNC_MONTH, or TRUNC_INVALID,
* TRUNC_INVALID means unsupported truncate level.
*/
def parseTruncLevel(format: UTF8String): Int = {
if (format == null) {
TRUNC_INVALID
} else {
format.toString.toUpperCase match {
case "YEAR" | "YYYY" | "YY" => TRUNC_TO_YEAR
case "MON" | "MONTH" | "MM" => TRUNC_TO_MONTH
case _ => TRUNC_INVALID
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,34 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null)
}

test("function to_date") {
checkEvaluation(
ToDate(Literal(Date.valueOf("2015-07-22"))),
DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-22")))
checkEvaluation(ToDate(Literal.create(null, DateType)), null)
}

test("function trunc") {
def testTrunc(input: Date, fmt: String, expected: Date): Unit = {
checkEvaluation(TruncDate(Literal.create(input, DateType), Literal.create(fmt, StringType)),
expected)
checkEvaluation(
TruncDate(Literal.create(input, DateType), NonFoldableLiteral.create(fmt, StringType)),
expected)
}
val date = Date.valueOf("2015-07-22")
Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach{ fmt =>
testTrunc(date, fmt, Date.valueOf("2015-01-01"))
}
Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt =>
testTrunc(date, fmt, Date.valueOf("2015-07-01"))
}
testTrunc(date, "DD", null)
testTrunc(date, null, null)
testTrunc(null, "MON", null)
testTrunc(null, null, null)
}

test("from_unixtime") {
val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"
Expand Down Expand Up @@ -405,5 +433,4 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(
UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,8 @@ object NonFoldableLiteral {
val lit = Literal(value)
NonFoldableLiteral(lit.value, lit.dataType)
}
def create(value: Any, dataType: DataType): NonFoldableLiteral = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not needed, is it? NonFoldableLiteral already has this if you don't define it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the value should be casted into Catalyst type, so it's needed.

val lit = Literal.create(value, dataType)
NonFoldableLiteral(lit.value, lit.dataType)
}
}
16 changes: 16 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2181,6 +2181,22 @@ object functions {
*/
def unix_timestamp(s: Column, p: String): Column = UnixTimestamp(s.expr, Literal(p))

/*
* Converts the column into DateType.
*
* @group datetime_funcs
* @since 1.5.0
*/
def to_date(e: Column): Column = ToDate(e.expr)

/**
* Returns date truncated to the unit specified by the format.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here we should document what the accepted values are for format, and give an example. Otherwise it is very hard for users to know what this function actually does.

*
* @group datetime_funcs
* @since 1.5.0
*/
def trunc(date: Column, format: String): Column = TruncDate(date.expr, Literal(format))

//////////////////////////////////////////////////////////////////////////////////////////////
// Collection functions
//////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,50 @@ class DateFunctionsSuite extends QueryTest {
Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-07-30"))))
}

test("function to_date") {
val d1 = Date.valueOf("2015-07-22")
val d2 = Date.valueOf("2015-07-01")
val t1 = Timestamp.valueOf("2015-07-22 10:00:00")
val t2 = Timestamp.valueOf("2014-12-31 23:59:59")
val s1 = "2015-07-22 10:00:00"
val s2 = "2014-12-31"
val df = Seq((d1, t1, s1), (d2, t2, s2)).toDF("d", "t", "s")

checkAnswer(
df.select(to_date(col("t"))),
Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31"))))
checkAnswer(
df.select(to_date(col("d"))),
Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01"))))
checkAnswer(
df.select(to_date(col("s"))),
Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31"))))

checkAnswer(
df.selectExpr("to_date(t)"),
Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31"))))
checkAnswer(
df.selectExpr("to_date(d)"),
Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01"))))
checkAnswer(
df.selectExpr("to_date(s)"),
Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31"))))
}

test("function trunc") {
val df = Seq(
(1, Timestamp.valueOf("2015-07-22 10:00:00")),
(2, Timestamp.valueOf("2014-12-31 00:00:00"))).toDF("i", "t")

checkAnswer(
df.select(trunc(col("t"), "YY")),
Seq(Row(Date.valueOf("2015-01-01")), Row(Date.valueOf("2014-01-01"))))

checkAnswer(
df.selectExpr("trunc(t, 'Month')"),
Seq(Row(Date.valueOf("2015-07-01")), Row(Date.valueOf("2014-12-01"))))
}

test("from_unixtime") {
val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"
Expand Down