Description
Hi everyone,
I am wondering whether it is possible to create custom multiclass objectives for xgboost in R.
Below is an MWE:
```r
rm(list = ls())
library(xgboost)
set.seed(0)

# Simulate some dummy data
n <- 1e2
num_features <- 2
num_class <- 3
# Class of each observation, shifted so labels lie in 0..(num_class - 1),
# which is the range xgboost expects for multiclass labels
y <- apply(rmultinom(n, 1, rep(1, num_class)), 2, function(yy) which(yy != 0)) - 1
x <- matrix(rnorm(n * num_features), n, num_features)

# Format for xgboost
dtrain <- xgb.DMatrix(x, label = y)

# A dummy custom objective to understand what is going on:
# browser() pauses execution so that preds and dtrain can be inspected
dummy_obj <- function(preds, dtrain) {
  browser()
}

# Feed to xgboost
params_dummy <- list(objective = dummy_obj,
                     num_class = num_class)
model_dummy <- xgb.train(params_dummy, data = dtrain, nrounds = 1)
```
Running the code above on my machine drops into the browser, where I see:
```
Called from: obj(pred, dtrain)
Browse[1]> preds
  [1] 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 [40] 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 [79] 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
Browse[1]> Q
```
Having read the source code of multiclass_obj.cc, I expected preds to be a vector of length n * num_class (300 here), or a matrix with n rows and num_class columns, or something similar. Instead, preds is a vector of length n (100 here).
So, my question is: am I missing something obvious (very likely) or is it currently not possible to implement custom multiclass objectives?
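For reference, here is a minimal sketch of the kind of objective I was hoping to write. It assumes (based on my reading of multiclass_obj.cc, not on any documented R API) that preds would arrive as a flat vector of length n * num_class in which the num_class raw scores of each row are stored contiguously, and that labels are 0-based. The helper name softmax_grad_hess and the hard-coded num_class = 3 in the wrapper are hypothetical. The stopifnot() check fails immediately if preds has length n, as in the browser session above.

```r
# Hypothetical helper (not part of xgboost): softmax gradient and diagonal
# Hessian for a flat prediction vector of assumed length n * num_class,
# where the num_class raw scores of row i are stored contiguously.
softmax_grad_hess <- function(preds, labels, num_class) {
  n <- length(labels)
  # Fails loudly if preds has only one score per row
  stopifnot(length(preds) == n * num_class)
  # Row i of this matrix holds the raw scores (margins) of observation i
  margins <- matrix(preds, nrow = n, ncol = num_class, byrow = TRUE)
  # Numerically stable softmax: subtract each row's maximum first
  margins <- margins - apply(margins, 1, max)
  p <- exp(margins)
  p <- p / rowSums(p)
  # One-hot encoding of the (assumed 0-based) labels
  onehot <- matrix(0, nrow = n, ncol = num_class)
  onehot[cbind(seq_len(n), labels + 1)] <- 1
  grad <- p - onehot
  # Diagonal of the softmax Hessian
  hess <- p * (1 - p)
  # Flatten back to the same row-contiguous layout
  list(grad = as.vector(t(grad)), hess = as.vector(t(hess)))
}

# Hypothetical custom objective wrapping the helper (num_class hard-coded)
softmax_obj <- function(preds, dtrain) {
  labels <- getinfo(dtrain, "label")
  softmax_grad_hess(preds, labels, num_class = 3)
}
```

With the MWE above, softmax_obj would take the place of dummy_obj in params_dummy.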
Below is my environment information, although I doubt it will be useful in this case.
Operating System:
OS X 10.12.6
Compiler:
Not relevant
Package used (python/R/jvm/C++):
R
sessionInfo():
```
R version 3.4.1 (2017-06-30)
Platform: x86_64-apple-darwin16.6.0 (64-bit)
Running under: macOS Sierra 10.12.6

Matrix products: default
BLAS: /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libLAPACK.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats graphics grDevices utils datasets methods base

other attached packages:
[1] xgboost_0.6-4 tvmisc_0.1.0

loaded via a namespace (and not attached):
[1] compiler_3.4.1 magrittr_1.5 Matrix_1.2-11 tools_3.4.1 rstudioapi_0.6
[6] stringi_1.1.5 grid_3.4.1 data.table_1.10.4 lattice_0.20-35
```