Skip to content

Commit 7567309

Browse files
authored
Merge pull request #1812 from cmu-delphi/ndefries/backfill/lp-solver-opt
[Backfill corrections] Force modeling to use GLPK if gurobi license info isn't provided
2 parents b44e22c + e3daff2 commit 7567309

File tree

11 files changed

+115
-24
lines changed

11 files changed

+115
-24
lines changed

backfill_corrections/delphiBackfillCorrection/DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ Suggests:
3232
knitr (>= 1.15),
3333
rmarkdown (>= 1.4),
3434
testthat (>= 1.0.1),
35-
covr (>= 2.2.2)
35+
covr (>= 2.2.2),
36+
mockr
3637
RoxygenNote: 7.2.0
3738
Encoding: UTF-8

backfill_corrections/delphiBackfillCorrection/R/model.R

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,10 @@ model_training_and_testing <- function(train_data, test_data, taus, covariates,
131131

132132
success = success + 1
133133
},
134-
error=function(e) {msg_ts("Training failed for ", model_path)}
134+
error=function(e) {
135+
msg_ts("Training failed for ", model_path, ". Check that your gurobi ",
136+
"license is valid and being passed properly to the program.")
137+
}
135138
)
136139
}
137140
if (success < length(taus)) {return (NULL)}

backfill_corrections/delphiBackfillCorrection/R/utils.R

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,37 @@ read_params <- function(path = "params.json", template_path = "params.json.templ
6767
# Model parameters
6868
if (!("taus" %in% names(params))) {params$taus <- TAUS}
6969
if (!("lambda" %in% names(params))) {params$lambda <- LAMBDA}
70-
if (!("lp_solver" %in% names(params))) {params$lp_solver <- LP_SOLVER}
7170
if (!("lag_pad" %in% names(params))) {params$lag_pad <- LAG_PAD}
7271

72+
if ("lp_solver" %in% names(params)) {
73+
params$lp_solver <- match.arg(params$lp_solver, c("gurobi", "glpk"))
74+
} else {
75+
params$lp_solver <- LP_SOLVER
76+
}
77+
if (params$lp_solver == "gurobi") {
78+
# Make call to gurobi CLI to check license. Returns a status of `0` if
79+
# license can be found and is valid.
80+
tryCatch(
81+
expr = {
82+
license_status <- run_cli("gurobi_cl")
83+
},
84+
error=function(e) {
85+
if (grepl("Error 10032: License has expired", e$message, fixed=TRUE)) {
86+
stop("The gurobi license has expired. Please renew or switch to ",
87+
"using glpk. lp_solver can be specified in params.json.")
88+
}
89+
msg_ts(e$message)
90+
license_status <- 1
91+
}
92+
)
93+
94+
if (license_status != 0) {
95+
warning("gurobi solver was requested but license information was ",
96+
"not available or not valid; using glpk instead")
97+
params$lp_solver <- "glpk"
98+
}
99+
}
100+
73101
# Data parameters
74102
if (!("num_col" %in% names(params))) {params$num_col <- "num"}
75103
if (!("denom_col" %in% names(params))) {params$denom_col <- "denom"}
@@ -109,6 +137,13 @@ read_params <- function(path = "params.json", template_path = "params.json.templ
109137
return(params)
110138
}
111139

140+
#' Wrapper for `base::system2` for testing convenience
141+
#'
142+
#' @param command string to run as command
143+
run_cli <- function(command) {
144+
system2(command)
145+
}
146+
112147
#' Create directory if not already existing
113148
#'
114149
#' @param path string specifying a directory to create

backfill_corrections/delphiBackfillCorrection/man/run_cli.Rd

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"input_dir": "./test.temp",
3+
"lp_solver": "glpk"
4+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"input_dir": "./test.temp",
3+
"lp_solver": "gurobi"
4+
}

backfill_corrections/delphiBackfillCorrection/unit-tests/testthat/params-run.json.template

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@
44
"ref_lag": 3,
55
"input_dir": "./input",
66
"export_dir": "./output",
7-
"cache_dir": "./cache"
7+
"cache_dir": "./cache",
8+
"lp_solver": "glpk"
89
}
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
{
2-
"input_dir": "./test.temp"
2+
"input_dir": "./test.temp",
3+
"lp_solver": "gurobi"
34
}

backfill_corrections/delphiBackfillCorrection/unit-tests/testthat/test-beta_prior_estimation.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ test_that("testing the squared error objection function given the beta prior", {
7070
test_that("testing the prior estimation", {
7171
dw <- "Sat_ref"
7272
priors <- est_priors(train_data, prior_test_data, geo, value_type, dw, TAUS,
73-
covariates, response, LP_SOLVER, lambda,
73+
covariates, response, "glpk", lambda,
7474
indicator, signal, geo_level, signal_suffix,
7575
training_end_date, training_start_date, model_save_dir)
7676
alpha <- priors[2]
@@ -110,7 +110,7 @@ test_that("testing the main beta prior adjustment function", {
110110
indicator, signal, geo_level, signal_suffix,
111111
lambda, value_type, geo,
112112
training_end_date, training_start_date, model_save_dir,
113-
taus = TAUS, lp_solver = LP_SOLVER)
113+
taus = TAUS, lp_solver = "glpk")
114114
updated_train_data <- updated_data[[1]]
115115
updated_test_data <- updated_data[[2]]
116116

backfill_corrections/delphiBackfillCorrection/unit-tests/testthat/test-model.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,21 +113,21 @@ test_that("testing generating or loading the model", {
113113

114114
# Generate the model and check again
115115
obj <- get_model(model_path, train_data, covariates, tau,
116-
lambda, LP_SOLVER, train_models=TRUE)
116+
lambda, "glpk", train_models=TRUE)
117117
expect_true(file.exists(model_path))
118118
created <- file.info(model_path)$ctime
119119

120120
# Check that the model was not generated again.
121121
obj <- get_model(model_path, train_data, covariates, tau,
122-
lambda, LP_SOLVER, train_models=FALSE)
122+
lambda, "glpk", train_models=FALSE)
123123
expect_equal(file.info(model_path)$ctime, created)
124124

125125
expect_silent(file.remove(model_path))
126126
})
127127

128128
test_that("testing model training and testing", {
129129
result <- model_training_and_testing(train_data, test_data, taus=TAUS, covariates=covariates,
130-
lp_solver=LP_SOLVER, lambda=lambda, test_lag=test_lag,
130+
lp_solver="glpk", lambda=lambda, test_lag=test_lag,
131131
geo=geo, value_type=value_type, model_save_dir=model_save_dir,
132132
indicator=indicator, signal=signal,
133133
geo_level=geo_level, signal_suffix=signal_suffix,

0 commit comments

Comments
 (0)