Skip to content

Commit 0521149

Browse files
committed
Changes to make SparkR DataFrame dplyr friendly.
Changes include: (1) rename `sortDF` to `arrange`; (2) add the new aliases `group_by`, `sample_frac`, and `summarize`; (3) add more user-friendly column addition (`mutate`) and renaming (`rename`); (4) support `mean` as an alias for `avg` in Scala, and support `n_distinct` and `n` as in dplyr.
1 parent 1712a7c commit 0521149

File tree

8 files changed

+248
-27
lines changed

8 files changed

+248
-27
lines changed

R/pkg/NAMESPACE

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ export("print.jobj")
99

1010
exportClasses("DataFrame")
1111

12-
exportMethods("cache",
12+
exportMethods("arrange",
13+
"cache",
1314
"collect",
1415
"columns",
1516
"count",
@@ -20,6 +21,7 @@ exportMethods("cache",
2021
"explain",
2122
"filter",
2223
"first",
24+
"group_by",
2325
"groupBy",
2426
"head",
2527
"insertInto",
@@ -29,12 +31,15 @@ exportMethods("cache",
2931
"length",
3032
"limit",
3133
"orderBy",
34+
"mutate",
3235
"names",
3336
"persist",
3437
"printSchema",
3538
"registerTempTable",
39+
"rename",
3640
"repartition",
3741
"sampleDF",
42+
"sample_frac",
3843
"saveAsParquetFile",
3944
"saveAsTable",
4045
"saveDF",
@@ -43,7 +48,7 @@ exportMethods("cache",
4348
"selectExpr",
4449
"show",
4550
"showDF",
46-
"sortDF",
51+
"summarize",
4752
"take",
4853
"unionAll",
4954
"unpersist",
@@ -73,6 +78,8 @@ exportMethods("abs",
7378
"max",
7479
"mean",
7580
"min",
81+
"n",
82+
"n_distinct",
7683
"rlike",
7784
"sqrt",
7885
"startsWith",

R/pkg/R/DataFrame.R

Lines changed: 116 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,7 @@ setMethod("distinct",
477477
#' @param withReplacement Sampling with replacement or not
478478
#' @param fraction The (rough) sample target fraction
479479
#' @rdname sampleDF
480+
#' @aliases sample_frac
480481
#' @export
481482
#' @examples
482483
#'\dontrun{
@@ -498,6 +499,15 @@ setMethod("sampleDF",
498499
dataFrame(sdf)
499500
})
500501

502+
#' @rdname sampleDF
503+
#' @aliases sampleDF
504+
setMethod("sample_frac",
505+
signature(x = "DataFrame", withReplacement = "logical",
506+
fraction = "numeric"),
507+
function(x, withReplacement, fraction) {
508+
sampleDF(x, withReplacement, fraction)
509+
})
510+
501511
#' Count
502512
#'
503513
#' Returns the number of rows in a DataFrame
@@ -679,7 +689,8 @@ setMethod("toRDD",
679689
#' @param x a DataFrame
680690
#' @return a GroupedData
681691
#' @seealso GroupedData
682-
#' @rdname DataFrame
692+
#' @aliases group_by
693+
#' @rdname groupBy
683694
#' @export
684695
#' @examples
685696
#' \dontrun{
@@ -702,18 +713,35 @@ setMethod("groupBy",
702713
groupedData(sgd)
703714
})
704715

705-
#' Agg
716+
#' @rdname groupBy
717+
#' @aliases group_by
718+
setMethod("group_by",
719+
signature(x = "DataFrame"),
720+
function(x, ...) {
721+
groupBy(x, ...)
722+
})
723+
724+
#' Summarize data across columns
706725
#'
707726
#' Compute aggregates by specifying a list of columns
708727
#'
709728
#' @rdname DataFrame
729+
#' @aliases summarize
710730
#' @export
711731
setMethod("agg",
712732
signature(x = "DataFrame"),
713733
function(x, ...) {
714734
agg(groupBy(x), ...)
715735
})
716736

737+
#' @rdname DataFrame
738+
#' @aliases agg
739+
setMethod("summarize",
740+
signature(x = "DataFrame"),
741+
function(x, ...) {
742+
agg(x, ...)
743+
})
744+
717745

718746
############################## RDD Map Functions ##################################
719747
# All of the following functions mirror the existing RDD map functions, #
@@ -881,7 +909,7 @@ setMethod("select",
881909
signature(x = "DataFrame", col = "list"),
882910
function(x, col) {
883911
cols <- lapply(col, function(c) {
884-
if (class(c)== "Column") {
912+
if (class(c) == "Column") {
885913
c@jc
886914
} else {
887915
col(c)@jc
@@ -941,6 +969,42 @@ setMethod("withColumn",
941969
select(x, x$"*", alias(col, colName))
942970
})
943971

972+
#' Mutate
#'
#' Return a new DataFrame with the specified columns added.
#'
#' @param x A DataFrame
#' @param ... One or more Column expressions, normally given as named
#'            arguments of the form name = col. A named argument is aliased
#'            to its name; an unnamed argument is selected as is.
#' @return A new DataFrame with the new columns added.
#' @rdname withColumn
#' @aliases withColumn
#' @export
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' sqlCtx <- sparkRSQL.init(sc)
#' path <- "path/to/file.json"
#' df <- jsonFile(sqlCtx, path)
#' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2)
#' names(newDF) # Will contain newCol, newCol2
#' }
setMethod("mutate",
          signature(x = "DataFrame"),
          function(x, ...) {
            cols <- list(...)
            stopifnot(length(cols) > 0)
            # Validate every argument, not only the first one, so that a
            # stray non-Column value fails here instead of inside select().
            isColumn <- vapply(cols,
                               function(arg) inherits(arg, "Column"),
                               logical(1))
            stopifnot(all(isColumn))
            ns <- names(cols)
            if (!is.null(ns)) {
              # Walk by position: cols[[name]] always resolves to the first
              # element with that name, so duplicated names would alias the
              # wrong element.
              for (i in seq_along(cols)) {
                if (ns[[i]] != "") {
                  cols[[i]] <- alias(cols[[i]], ns[[i]])
                }
              }
            }
            do.call(select, c(x, x$"*", cols))
          })
1007+
9441008
#' WithColumnRenamed
9451009
#'
9461010
#' Rename an existing column in a DataFrame.
@@ -972,29 +1036,67 @@ setMethod("withColumnRenamed",
9721036
select(x, cols)
9731037
})
9741038

1039+
#' Rename
#'
#' Rename one or more existing columns in a DataFrame.
#'
#' @param x A DataFrame
#' @param ... Named Column arguments of the form
#'            new_column_name = existing_column. Every argument must be named.
#' @return A DataFrame with the column name changed.
#' @rdname withColumnRenamed
#' @aliases withColumnRenamed
#' @export
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' sqlCtx <- sparkRSQL.init(sc)
#' path <- "path/to/file.json"
#' df <- jsonFile(sqlCtx, path)
#' newDF <- rename(df, newCol1 = df$col1)
#' }
setMethod("rename",
          signature(x = "DataFrame"),
          function(x, ...) {
            renameCols <- list(...)
            stopifnot(length(renameCols) > 0)
            # Validate every argument, not only the first one.
            isColumn <- vapply(renameCols,
                               function(arg) inherits(arg, "Column"),
                               logical(1))
            stopifnot(all(isColumn))
            # The argument name is the new column name, so each argument
            # needs a non-empty name; fail early with a clear condition
            # instead of a subscript error further down.
            newNames <- names(renameCols)
            stopifnot(!is.null(newNames), all(nzchar(newNames)))
            # String form of each Column expression, compared below against
            # the DataFrame's column names.
            oldNames <- vapply(renameCols,
                               function(arg) callJMethod(arg@jc, "toString"),
                               character(1),
                               USE.NAMES = FALSE)
            cols <- lapply(columns(x), function(c) {
              idx <- match(c, oldNames)
              if (is.na(idx)) {
                col(c)
              } else {
                alias(col(c), newNames[[idx]])
              }
            })
            select(x, cols)
          })
1076+
9751077
setClassUnion("characterOrColumn", c("character", "Column"))
9761078

977-
#' SortDF
1079+
#' Arrange
9781080
#'
9791081
#' Sort a DataFrame by the specified column(s).
9801082
#'
9811083
#' @param x A DataFrame to be sorted.
9821084
#' @param col Either a Column object or character vector indicating the field to sort on
9831085
#' @param ... Additional sorting fields
9841086
#' @return A DataFrame where all elements are sorted.
985-
#' @rdname sortDF
1087+
#' @rdname arrange
9861088
#' @export
9871089
#' @examples
9881090
#'\dontrun{
9891091
#' sc <- sparkR.init()
9901092
#' sqlCtx <- sparkRSQL.init(sc)
9911093
#' path <- "path/to/file.json"
9921094
#' df <- jsonFile(sqlCtx, path)
993-
#' sortDF(df, df$col1)
994-
#' sortDF(df, "col1")
995-
#' sortDF(df, asc(df$col1), desc(abs(df$col2)))
1095+
#' arrange(df, df$col1)
1096+
#' arrange(df, "col1")
1097+
#' arrange(df, asc(df$col1), desc(abs(df$col2)))
9961098
#' }
997-
setMethod("sortDF",
1099+
setMethod("arrange",
9981100
signature(x = "DataFrame", col = "characterOrColumn"),
9991101
function(x, col, ...) {
10001102
if (class(col) == "character") {
@@ -1008,20 +1110,21 @@ setMethod("sortDF",
10081110
dataFrame(sdf)
10091111
})
10101112

1011-
#' @rdname sortDF
1113+
#' @rdname arrange
1114+
#' @aliases orderBy,DataFrame,function-method
10121115
#' @export
10131116
setMethod("orderBy",
10141117
signature(x = "DataFrame", col = "characterOrColumn"),
10151118
function(x, col) {
1016-
sortDF(x, col)
1119+
arrange(x, col)
10171120
})
10181121

10191122
#' Filter
10201123
#'
10211124
#' Filter the rows of a DataFrame according to a given condition.
10221125
#'
10231126
#' @param x A DataFrame to be filtered.
1024-
#' @param condition The condition to sort on. This may either be a Column expression
1127+
#' @param condition The condition to filter on. This may either be a Column expression
10251128
#' or a string containing a SQL statement
10261129
#' @return A DataFrame containing only the rows that meet the condition.
10271130
#' @rdname filter
@@ -1101,6 +1204,7 @@ setMethod("join",
11011204
#'
11021205
#' Return a new DataFrame containing the union of rows in this DataFrame
11031206
#' and another DataFrame. This is equivalent to `UNION ALL` in SQL.
1207+
#' Note that this does not remove duplicate rows across the two DataFrames.
11041208
#'
11051209
#' @param x A Spark DataFrame
11061210
#' @param y A Spark DataFrame

R/pkg/R/column.R

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ createMethods()
130130

131131
#' alias
132132
#'
133+
#' @rdname column
134+
#'
133135
#' Set a new name for a column
134136
setMethod("alias",
135137
signature(object = "Column"),
@@ -143,6 +145,8 @@ setMethod("alias",
143145

144146
#' An expression that returns a substring.
145147
#'
148+
#' @rdname column
149+
#'
146150
#' @param start starting position
147151
#' @param stop ending position
148152
setMethod("substr", signature(x = "Column"),
@@ -152,6 +156,9 @@ setMethod("substr", signature(x = "Column"),
152156
})
153157

154158
#' Casts the column to a different data type.
159+
#'
160+
#' @rdname column
161+
#'
155162
#' @examples
156163
#' \dontrun{
157164
#' cast(df$age, "string")
@@ -173,8 +180,9 @@ setMethod("cast",
173180

174181
#' Approx Count Distinct
175182
#'
176-
#' Returns the approximate number of distinct items in a group.
183+
#' @rdname column
177184
#'
185+
#' Returns the approximate number of distinct items in a group.
178186
setMethod("approxCountDistinct",
179187
signature(x = "Column"),
180188
function(x, rsd = 0.95) {
@@ -184,8 +192,9 @@ setMethod("approxCountDistinct",
184192

185193
#' Count Distinct
186194
#'
187-
#' returns the number of distinct items in a group.
195+
#' @rdname column
188196
#'
197+
#' returns the number of distinct items in a group.
189198
setMethod("countDistinct",
190199
signature(x = "Column"),
191200
function(x, ...) {
@@ -197,3 +206,18 @@ setMethod("countDistinct",
197206
column(jc)
198207
})
199208

209+
#' @rdname column
210+
#' @aliases countDistinct
211+
setMethod("n_distinct",
212+
signature(x = "Column"),
213+
function(x, ...) {
214+
countDistinct(x, ...)
215+
})
216+
217+
#' @rdname column
218+
#' @aliases count
219+
setMethod("n",
220+
signature(x = "Column"),
221+
function(x) {
222+
count(x)
223+
})

0 commit comments

Comments
 (0)