Skip to content

Commit 12a41bb

Browse files
committed
fix tests
1 parent 49f36f3 commit 12a41bb

File tree

3 files changed

+5
-6
lines changed

3 files changed

+5
-6
lines changed

R/pkg/R/mllib.R

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
#' @export
2323
setClass("PipelineModel", representation(model = "jobj"))
2424

25-
#' @tile S4 class that represents a NaiveBayesModel
25+
#' @title S4 class that represents a NaiveBayesModel
2626
#' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper
2727
#' @export
2828
setClass("NaiveBayesModel", representation(jobj = "jobj"))
@@ -66,7 +66,7 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram
6666
return(new("PipelineModel", model = model))
6767
})
6868

69-
#' Make predictions from a amodel
69+
#' Make predictions from a model
7070
#'
7171
#' Makes predictions from a model produced by glm(), similarly to R's predict().
7272
#'
@@ -268,7 +268,6 @@ setMethod("fitted", signature(object = "PipelineModel"),
268268
#'}
269269
setMethod("naiveBayes", signature(formula = "formula", data = "DataFrame"),
270270
function(formula, data, laplace = 0, ...) {
271-
data <- na.omit(data)
272271
formula <- paste(deparse(formula), collapse = "")
273272
jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
274273
formula, data@sdf, laplace)

R/pkg/inst/tests/testthat/test_mllib.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,9 @@ test_that("naiveBayes", {
186186
df <- suppressWarnings(createDataFrame(sqlContext, t1))
187187
m <- naiveBayes(Survived ~ ., data = df)
188188
s <- summary(m)
189-
expect_equal(s$apriori[1, "Yes"], 0.5833333, tolerance = 1e-6)
189+
expect_equal(as.double(s$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6)
190190
expect_equal(sum(s$apriori), 1)
191-
expect_equal(s$tables["Yes", "Age_Adult"], 0.5714286, tolerance = 1e-6)
191+
expect_equal(as.double(s$tables["Yes", "Age_Adult"]), 0.5714286, tolerance = 1e-6)
192192
p <- collect(select(predict(m, df), "prediction"))
193193
expect_equal(p$prediction, c("Yes", "Yes", "Yes", "Yes", "No", "No", "Yes", "Yes", "No", "No",
194194
"Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No",

mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.ml.r
1919

2020
import org.apache.spark.ml.{Pipeline, PipelineModel}
21-
import org.apache.spark.ml.attribute.{AttributeGroup, Attribute, NominalAttribute}
21+
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
2222
import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
2323
import org.apache.spark.ml.feature.{IndexToString, RFormula}
2424
import org.apache.spark.sql.DataFrame

0 commit comments

Comments
 (0)