Skip to content

Commit 06da5fd

Browse files
felixcheung authored and shivaram committed
[SPARK-9319][SPARKR] Add support for setting column names, types
Add support for colnames, colnames<-, coltypes<-. Also added tests for names, names<-, which had no tests previously. I merged with PR 8984 (coltypes). Clicked the wrong thing, screwed up the PR. Recreated it here. Was #9218. shivaram sun-rui Author: felixcheung <[email protected]> Closes #9654 from felixcheung/colnamescoltypes. (cherry picked from commit c793d2d) Signed-off-by: Shivaram Venkataraman <[email protected]>
1 parent 2503a43 commit 06da5fd

File tree

5 files changed

+185
-55
lines changed

5 files changed

+185
-55
lines changed

R/pkg/NAMESPACE

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@ exportMethods("arrange",
2727
"attach",
2828
"cache",
2929
"collect",
30+
"colnames",
31+
"colnames<-",
3032
"coltypes",
33+
"coltypes<-",
3134
"columns",
3235
"count",
3336
"cov",
@@ -56,6 +59,7 @@ exportMethods("arrange",
5659
"mutate",
5760
"na.omit",
5861
"names",
62+
"names<-",
5963
"ncol",
6064
"nrow",
6165
"orderBy",
@@ -276,4 +280,4 @@ export("structField",
276280
"structType",
277281
"structType.jobj",
278282
"structType.structField",
279-
"print.structType")
283+
"print.structType")

R/pkg/R/DataFrame.R

Lines changed: 117 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ setMethod("dtypes",
254254
#' @family DataFrame functions
255255
#' @rdname columns
256256
#' @name columns
257+
257258
#' @export
258259
#' @examples
259260
#'\dontrun{
@@ -262,6 +263,7 @@ setMethod("dtypes",
262263
#' path <- "path/to/file.json"
263264
#' df <- jsonFile(sqlContext, path)
264265
#' columns(df)
266+
#' colnames(df)
265267
#'}
266268
setMethod("columns",
267269
signature(x = "DataFrame"),
@@ -290,6 +292,121 @@ setMethod("names<-",
290292
}
291293
})
292294

295+
#' @rdname columns
296+
#' @name colnames
297+
setMethod("colnames",
298+
signature(x = "DataFrame"),
299+
function(x) {
300+
columns(x)
301+
})
302+
303+
#' @rdname columns
304+
#' @name colnames<-
305+
setMethod("colnames<-",
306+
signature(x = "DataFrame", value = "character"),
307+
function(x, value) {
308+
sdf <- callJMethod(x@sdf, "toDF", as.list(value))
309+
dataFrame(sdf)
310+
})
311+
312+
#' coltypes
313+
#'
314+
#' Get column types of a DataFrame
315+
#'
316+
#' @param x A SparkSQL DataFrame
317+
#' @return value A character vector with the column types of the given DataFrame
318+
#' @rdname coltypes
319+
#' @name coltypes
320+
#' @family DataFrame functions
321+
#' @export
322+
#' @examples
323+
#'\dontrun{
324+
#' irisDF <- createDataFrame(sqlContext, iris)
325+
#' coltypes(irisDF)
326+
#'}
327+
setMethod("coltypes",
328+
signature(x = "DataFrame"),
329+
function(x) {
330+
# Get the data types of the DataFrame by invoking dtypes() function
331+
types <- sapply(dtypes(x), function(x) {x[[2]]})
332+
333+
# Map Spark data types into R's data types using DATA_TYPES environment
334+
rTypes <- sapply(types, USE.NAMES=F, FUN=function(x) {
335+
# Check for primitive types
336+
type <- PRIMITIVE_TYPES[[x]]
337+
338+
if (is.null(type)) {
339+
# Check for complex types
340+
for (t in names(COMPLEX_TYPES)) {
341+
if (substring(x, 1, nchar(t)) == t) {
342+
type <- COMPLEX_TYPES[[t]]
343+
break
344+
}
345+
}
346+
347+
if (is.null(type)) {
348+
stop(paste("Unsupported data type: ", x))
349+
}
350+
}
351+
type
352+
})
353+
354+
# Find which types don't have mapping to R
355+
naIndices <- which(is.na(rTypes))
356+
357+
# Assign the original scala data types to the unmatched ones
358+
rTypes[naIndices] <- types[naIndices]
359+
360+
rTypes
361+
})
362+
363+
#' coltypes
364+
#'
365+
#' Set the column types of a DataFrame.
366+
#'
367+
#' @param x A SparkSQL DataFrame
368+
#' @param value A character vector with the target column types for the given
369+
#' DataFrame. Column types can be one of integer, numeric/double, character, logical, or NA
370+
#' to keep that column as-is.
371+
#' @rdname coltypes
372+
#' @name coltypes<-
373+
#' @export
374+
#' @examples
375+
#'\dontrun{
376+
#' sc <- sparkR.init()
377+
#' sqlContext <- sparkRSQL.init(sc)
378+
#' path <- "path/to/file.json"
379+
#' df <- jsonFile(sqlContext, path)
380+
#' coltypes(df) <- c("character", "integer")
381+
#' coltypes(df) <- c(NA, "numeric")
382+
#'}
383+
setMethod("coltypes<-",
384+
signature(x = "DataFrame", value = "character"),
385+
function(x, value) {
386+
cols <- columns(x)
387+
ncols <- length(cols)
388+
if (length(value) == 0) {
389+
stop("Cannot set types of an empty DataFrame with no Column")
390+
}
391+
if (length(value) != ncols) {
392+
stop("Length of type vector should match the number of columns for DataFrame")
393+
}
394+
newCols <- lapply(seq_len(ncols), function(i) {
395+
col <- getColumn(x, cols[i])
396+
if (!is.na(value[i])) {
397+
stype <- rToSQLTypes[[value[i]]]
398+
if (is.null(stype)) {
399+
stop("Only atomic type is supported for column types")
400+
}
401+
cast(col, stype)
402+
} else {
403+
col
404+
}
405+
})
406+
nx <- select(x, newCols)
407+
dataFrame(nx@sdf)
408+
})
409+
293410
#' Register Temporary Table
294411
#'
295412
#' Registers a DataFrame as a Temporary Table in the SQLContext
@@ -2102,52 +2219,3 @@ setMethod("with",
21022219
newEnv <- assignNewEnv(data)
21032220
eval(substitute(expr), envir = newEnv, enclos = newEnv)
21042221
})
2105-
2106-
#' Returns the column types of a DataFrame.
2107-
#'
2108-
#' @name coltypes
2109-
#' @title Get column types of a DataFrame
2110-
#' @family dataframe_funcs
2111-
#' @param x (DataFrame)
2112-
#' @return value (character) A character vector with the column types of the given DataFrame
2113-
#' @rdname coltypes
2114-
#' @examples \dontrun{
2115-
#' irisDF <- createDataFrame(sqlContext, iris)
2116-
#' coltypes(irisDF)
2117-
#' }
2118-
setMethod("coltypes",
2119-
signature(x = "DataFrame"),
2120-
function(x) {
2121-
# Get the data types of the DataFrame by invoking dtypes() function
2122-
types <- sapply(dtypes(x), function(x) {x[[2]]})
2123-
2124-
# Map Spark data types into R's data types using DATA_TYPES environment
2125-
rTypes <- sapply(types, USE.NAMES=F, FUN=function(x) {
2126-
2127-
# Check for primitive types
2128-
type <- PRIMITIVE_TYPES[[x]]
2129-
2130-
if (is.null(type)) {
2131-
# Check for complex types
2132-
for (t in names(COMPLEX_TYPES)) {
2133-
if (substring(x, 1, nchar(t)) == t) {
2134-
type <- COMPLEX_TYPES[[t]]
2135-
break
2136-
}
2137-
}
2138-
2139-
if (is.null(type)) {
2140-
stop(paste("Unsupported data type: ", x))
2141-
}
2142-
}
2143-
type
2144-
})
2145-
2146-
# Find which types don't have mapping to R
2147-
naIndices <- which(is.na(rTypes))
2148-
2149-
# Assign the original scala data types to the unmatched ones
2150-
rTypes[naIndices] <- types[naIndices]
2151-
2152-
rTypes
2153-
})

R/pkg/R/generics.R

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,22 @@ setGeneric("agg", function (x, ...) { standardGeneric("agg") })
385385
#' @export
386386
setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") })
387387

388+
#' @rdname columns
389+
#' @export
390+
setGeneric("colnames", function(x, do.NULL = TRUE, prefix = "col") { standardGeneric("colnames") })
391+
392+
#' @rdname columns
393+
#' @export
394+
setGeneric("colnames<-", function(x, value) { standardGeneric("colnames<-") })
395+
396+
#' @rdname coltypes
397+
#' @export
398+
setGeneric("coltypes", function(x) { standardGeneric("coltypes") })
399+
400+
#' @rdname coltypes
401+
#' @export
402+
setGeneric("coltypes<-", function(x, value) { standardGeneric("coltypes<-") })
403+
388404
#' @rdname schema
389405
#' @export
390406
setGeneric("columns", function(x) {standardGeneric("columns") })
@@ -1081,7 +1097,3 @@ setGeneric("attach")
10811097
#' @rdname with
10821098
#' @export
10831099
setGeneric("with")
1084-
1085-
#' @rdname coltypes
1086-
#' @export
1087-
setGeneric("coltypes", function(x) { standardGeneric("coltypes") })

R/pkg/R/types.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,11 @@ COMPLEX_TYPES <- list(
4141

4242
# The full list of data types.
4343
DATA_TYPES <- as.environment(c(as.list(PRIMITIVE_TYPES), COMPLEX_TYPES))
44+
45+
# An environment for mapping R to Scala, names are R types and values are Scala types.
46+
rToSQLTypes <- as.environment(list(
47+
"integer" = "integer", # in R, integer is 32bit
48+
"numeric" = "double", # in R, numeric == double which is 64bit
49+
"double" = "double",
50+
"character" = "string",
51+
"logical" = "boolean"))

R/pkg/inst/tests/test_sparkSQL.R

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,26 @@ test_that("schema(), dtypes(), columns(), names() return the correct values/form
622622
expect_equal(testNames[2], "name")
623623
})
624624

625+
test_that("names() colnames() set the column names", {
626+
df <- jsonFile(sqlContext, jsonPath)
627+
names(df) <- c("col1", "col2")
628+
expect_equal(colnames(df)[2], "col2")
629+
630+
colnames(df) <- c("col3", "col4")
631+
expect_equal(names(df)[1], "col3")
632+
633+
# Test base::colnames base::names
634+
m2 <- cbind(1, 1:4)
635+
expect_equal(colnames(m2, do.NULL = FALSE), c("col1", "col2"))
636+
colnames(m2) <- c("x","Y")
637+
expect_equal(colnames(m2), c("x", "Y"))
638+
639+
z <- list(a = 1, b = "c", c = 1:3)
640+
expect_equal(names(z)[3], "c")
641+
names(z)[3] <- "c2"
642+
expect_equal(names(z)[3], "c2")
643+
})
644+
625645
test_that("head() and first() return the correct data", {
626646
df <- jsonFile(sqlContext, jsonPath)
627647
testHead <- head(df)
@@ -1617,7 +1637,7 @@ test_that("with() on a DataFrame", {
16171637
expect_equal(nrow(sum2), 35)
16181638
})
16191639

1620-
test_that("Method coltypes() to get R's data types of a DataFrame", {
1640+
test_that("Method coltypes() to get and set R's data types of a DataFrame", {
16211641
expect_equal(coltypes(irisDF), c(rep("numeric", 4), "character"))
16221642

16231643
data <- data.frame(c1=c(1,2,3),
@@ -1636,6 +1656,24 @@ test_that("Method coltypes() to get R's data types of a DataFrame", {
16361656
x <- createDataFrame(sqlContext, list(list(as.environment(
16371657
list("a"="b", "c"="d", "e"="f")))))
16381658
expect_equal(coltypes(x), "map<string,string>")
1659+
1660+
df <- selectExpr(jsonFile(sqlContext, jsonPath), "name", "(age * 1.21) as age")
1661+
expect_equal(dtypes(df), list(c("name", "string"), c("age", "decimal(24,2)")))
1662+
1663+
df1 <- select(df, cast(df$age, "integer"))
1664+
coltypes(df) <- c("character", "integer")
1665+
expect_equal(dtypes(df), list(c("name", "string"), c("age", "int")))
1666+
value <- collect(df[, 2])[[3, 1]]
1667+
expect_equal(value, collect(df1)[[3, 1]])
1668+
expect_equal(value, 22)
1669+
1670+
coltypes(df) <- c(NA, "numeric")
1671+
expect_equal(dtypes(df), list(c("name", "string"), c("age", "double")))
1672+
1673+
expect_error(coltypes(df) <- c("character"),
1674+
"Length of type vector should match the number of columns for DataFrame")
1675+
expect_error(coltypes(df) <- c("environment", "list"),
1676+
"Only atomic type is supported for column types")
16391677
})
16401678

16411679
unlink(parquetPath)

0 commit comments

Comments
 (0)