@@ -665,6 +665,8 @@ setMethod("predict", signature(object = "KMeansModel"),
665665# ' @param tol convergence tolerance of iterations.
666666# ' @param stepSize stepSize parameter.
667667# ' @param seed seed parameter for weights initialization.
668+ # ' @param initialWeights initialWeights parameter for weights initialization, it should be a
669+ # ' numeric vector.
668670# ' @param ... additional arguments passed to the method.
669671# ' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model.
670672# ' @rdname spark.mlp
@@ -677,8 +679,9 @@ setMethod("predict", signature(object = "KMeansModel"),
677679# ' df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm")
678680# '
679681# ' # fit a Multilayer Perceptron Classification Model
680- # ' model <- spark.mlp(df, blockSize = 128, layers = c(4, 5, 4, 3), solver = "l-bfgs",
681- # ' maxIter = 100, tol = 0.5, stepSize = 1, seed = 1)
682+ # ' model <- spark.mlp(df, blockSize = 128, layers = c(4, 3), solver = "l-bfgs",
683+ # ' maxIter = 100, tol = 0.5, stepSize = 1, seed = 1,
684+ # ' initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9))
682685# '
683686# ' # get the summary of the model
684687# ' summary(model)
@@ -695,7 +698,7 @@ setMethod("predict", signature(object = "KMeansModel"),
695698# ' @note spark.mlp since 2.1.0
696699setMethod ("spark.mlp ", signature(data = "SparkDataFrame"),
697700 function (data , layers , blockSize = 128 , solver = " l-bfgs" , maxIter = 100 ,
698- tol = 1E-6 , stepSize = 0.03 , seed = NULL ) {
701+ tol = 1E-6 , stepSize = 0.03 , seed = NULL , initialWeights = NULL ) {
699702 if (is.null(layers )) {
700703 stop (" layers must be a integer vector with length > 1." )
701704 }
@@ -706,10 +709,13 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame"),
706709 if (! is.null(seed )) {
707710 seed <- as.character(as.integer(seed ))
708711 }
712+ if (! is.null(initialWeights )) {
713+ initialWeights <- as.array(as.numeric(na.omit(initialWeights )))
714+ }
709715 jobj <- callJStatic(" org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper" ,
710716 " fit" , data @ sdf , as.integer(blockSize ), as.array(layers ),
711717 as.character(solver ), as.integer(maxIter ), as.numeric(tol ),
712- as.numeric(stepSize ), seed )
718+ as.numeric(stepSize ), seed , initialWeights )
713719 new(" MultilayerPerceptronClassificationModel" , jobj = jobj )
714720 })
715721
0 commit comments