Skip to content

Commit bf7a89b

Browse files
Ethan Che authored and facebook-github-bot committed
log values for HVKG using LogEHVI
Summary: HVKG now outputs log values when _log=True, using the LogEHVI value function. This change also adds a _log flag to InverseCostWeightedUtility so that it outputs utilities in log-space. If _log=True, HVKG assumes that (1) current_value is in log-space and (2) cost_aware_utility outputs in log-space (an error is raised if cost_aware_utility does not have a _log flag or if its _log=False). Note that InverseCostWeightedUtility applies the logarithmic transform to the input costs itself; it assumes that the input costs are in the original space. This is so that one does not need to make any direct modification to the cost function, and because InverseCostWeightedUtility already does some pre-processing of the costs (e.g. clipping). TL;DR: HVKG assumes all of its inputs are in log-space, but InverseCostWeightedUtility does not (its cost inputs stay in the original space). Rollback Plan: Differential Revision: D80263869
1 parent a9f1f86 commit bf7a89b

File tree

3 files changed

+251
-97
lines changed

3 files changed

+251
-97
lines changed

botorch/acquisition/cost_aware.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from botorch.models.deterministic import DeterministicModel
2626
from botorch.models.gpytorch import GPyTorchModel
2727
from botorch.sampling.base import MCSampler
28+
from botorch.utils.safe_math import log as safe_log
2829
from pyre_extensions import none_throws
2930
from torch import Tensor
3031
from torch.nn import Module
@@ -113,6 +114,7 @@ def __init__(
113114
use_mean: bool = True,
114115
cost_objective: MCAcquisitionObjective | None = None,
115116
min_cost: float = 1e-2,
117+
log: bool = False,
116118
) -> None:
117119
r"""Cost-aware utility that weights increase in utility by inverse cost.
118120
For negative increases in utility, the utility is instead scaled by the
@@ -148,6 +150,7 @@ def __init__(
148150
self.cost_objective: MCAcquisitionObjective = cost_objective
149151
self._use_mean = use_mean
150152
self._min_cost = min_cost
153+
self._log = log
151154

152155
def forward(
153156
self,
@@ -215,5 +218,10 @@ def forward(
215218

216219
# compute and return the ratio on the sample level - If `use_mean=True`
217220
# this operation involves broadcasting the cost across fantasies.
218-
# We multiply by the cost if the deltas are <= 0, see discussion #2914
219-
return torch.where(deltas > 0, deltas / cost, deltas * cost)
221+
if self._log:
222+
# if _log is True then input deltas are in log space
223+
# so original deltas cannot be <= 0
224+
return deltas - safe_log(cost)
225+
else:
226+
# We multiply by the cost if the deltas are <= 0, see discussion #2914
227+
return torch.where(deltas > 0, deltas / cost, deltas * cost)

botorch/acquisition/multi_objective/hypervolume_knowledge_gradient.py

Lines changed: 68 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,12 @@
3232
from botorch.acquisition.decoupled import DecoupledAcquisitionFunction
3333
from botorch.acquisition.knowledge_gradient import ProjectedAcquisitionFunction
3434
from botorch.acquisition.multi_objective.base import MultiObjectiveMCAcquisitionFunction
35+
from botorch.acquisition.multi_objective.logei import qLogExpectedHypervolumeImprovement
3536
from botorch.acquisition.multi_objective.monte_carlo import (
3637
qExpectedHypervolumeImprovement,
3738
)
3839
from botorch.acquisition.multi_objective.objective import MCMultiOutputObjective
39-
from botorch.exceptions.errors import UnsupportedError
40+
from botorch.exceptions.errors import BotorchError, UnsupportedError
4041
from botorch.exceptions.warnings import NumericsWarning
4142
from botorch.models.deterministic import PosteriorMeanModel
4243
from botorch.models.model import Model
@@ -47,6 +48,7 @@
4748
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
4849
FastNondominatedPartitioning,
4950
)
51+
from botorch.utils.safe_math import logdiffexp, logmeanexp
5052
from botorch.utils.transforms import (
5153
average_over_ensemble_models,
5254
match_batch_shape,
@@ -91,6 +93,7 @@ def __init__(
9193
current_value: Tensor | None = None,
9294
use_posterior_mean: bool = True,
9395
cost_aware_utility: CostAwareUtility | None = None,
96+
log: bool = False,
9497
) -> None:
9598
r"""q-Hypervolume Knowledge Gradient.
9699
@@ -133,6 +136,9 @@ def __init__(
133136
[Daulton2023hvkg]_ for details.
134137
cost_aware_utility: A CostAwareUtility specifying the cost function for
135138
evaluating the `X` on the objectives indicated by `evaluation_mask`.
139+
log: If True, then returns the log of the HVKG value. If True, then it
140+
expects current_value to be in log-space and cost_aware_utility to
141+
output log utilities.
136142
"""
137143
if sampler is None:
138144
# base samples should be fixed for joint optimization over X, X_fantasies
@@ -170,6 +176,8 @@ def __init__(
170176
self.cost_aware_utility = cost_aware_utility
171177
self._cost_sampler = None
172178

179+
self._log = log
180+
173181
@property
174182
def cost_sampler(self):
175183
if self._cost_sampler is None:
@@ -242,6 +250,7 @@ def forward(self, X: Tensor) -> Tensor:
242250
objective=self.objective,
243251
sampler=self.inner_sampler,
244252
use_posterior_mean=self.use_posterior_mean,
253+
log=self._log,
245254
)
246255

247256
# make sure to propagate gradients to the fantasy model train inputs
@@ -259,9 +268,23 @@ def forward(self, X: Tensor) -> Tensor:
259268
values = value_function(X=X_fantasies.reshape(shape)) # num_fantasies x b
260269

261270
if self.current_value is not None:
262-
values = values - self.current_value
271+
if self._log:
272+
values = logdiffexp(self.current_value, values)
273+
else:
274+
values = values - self.current_value
263275

264276
if self.cost_aware_utility is not None:
277+
if self._log:
278+
# check whether cost_aware_utility has a _log flag
279+
# raises an error if it does not or if _log is False
280+
if (
281+
not hasattr(self.cost_aware_utility, "_log")
282+
or not self.cost_aware_utility._log
283+
):
284+
# Adjacent string literals are joined with no separator, so the first
# fragment carries the trailing space that keeps the words apart.
raise BotorchError(
    "Cost-aware HVKG has _log=True and requires cost_aware_utility "
    "to output log utilities."
)
265288
values = self.cost_aware_utility(
266289
# exclude pending points
267290
X=X_actual[..., :q, :],
@@ -271,7 +294,10 @@ def forward(self, X: Tensor) -> Tensor:
271294
)
272295

273296
# return average over the fantasy samples
274-
return values.mean(dim=0)
297+
if self._log:
298+
return logmeanexp(values, dim=0)
299+
else:
300+
return values.mean(dim=0)
275301

276302
def get_augmented_q_batch_size(self, q: int) -> int:
277303
r"""Get augmented q batch size for one-shot optimization.
@@ -329,6 +355,7 @@ def __init__(
329355
valfunc_cls: type[AcquisitionFunction] | None = None,
330356
valfunc_argfac: Callable[[Model], dict[str, Any]] | None = None,
331357
use_posterior_mean: bool = True,
358+
log: bool = False,
332359
**kwargs: Any,
333360
) -> None:
334361
r"""Multi-Fidelity q-Knowledge Gradient (one-shot optimization).
@@ -376,6 +403,7 @@ def __init__(
376403
valfunc_argfac: An argument factory, i.e. callable that maps a `Model`
377404
to a dictionary of kwargs for the terminal value function (e.g.
378405
`best_f` for `ExpectedImprovement`).
406+
log: If True, then returns the log of the HVKG value.
379407
"""
380408

381409
super().__init__(
@@ -392,6 +420,7 @@ def __init__(
392420
current_value=current_value,
393421
use_posterior_mean=use_posterior_mean,
394422
cost_aware_utility=cost_aware_utility,
423+
log=log,
395424
)
396425
self.project = project
397426
if kwargs.get("expand") is not None:
@@ -465,6 +494,7 @@ def forward(self, X: Tensor) -> Tensor:
465494
valfunc_cls=self.valfunc_cls,
466495
valfunc_argfac=self.valfunc_argfac,
467496
use_posterior_mean=self.use_posterior_mean,
497+
log=self._log,
468498
)
469499

470500
# make sure to propagate gradients to the fantasy model train inputs
@@ -481,9 +511,24 @@ def forward(self, X: Tensor) -> Tensor:
481511
)
482512
values = value_function(X=X_fantasies.reshape(shape)) # num_fantasies x b
483513
if self.current_value is not None:
484-
values = values - self.current_value
514+
if self._log:
515+
# Assumes current value is in log-space
516+
values = logdiffexp(self.current_value, values)
517+
else:
518+
values = values - self.current_value
485519

486520
if self.cost_aware_utility is not None:
521+
if self._log:
522+
# check whether cost_aware_utility has a _log flag
523+
# raises an error if it does not or if _log is False
524+
if (
525+
not hasattr(self.cost_aware_utility, "_log")
526+
or not self.cost_aware_utility._log
527+
):
528+
# Adjacent string literals are joined with no separator, so the first
# fragment carries the trailing space that keeps the words apart.
raise BotorchError(
    "Cost-aware HVKG has _log=True and requires cost_aware_utility "
    "to output log utilities."
)
487532
values = self.cost_aware_utility(
488533
# exclude pending points
489534
X=X_actual[..., :q, :],
@@ -493,7 +538,10 @@ def forward(self, X: Tensor) -> Tensor:
493538
)
494539

495540
# return average over the fantasy samples
496-
return values.mean(dim=0)
541+
if self._log:
542+
return logmeanexp(values, dim=0)
543+
else:
544+
return values.mean(dim=0)
497545

498546

499547
def _get_hv_value_function(
@@ -505,6 +553,7 @@ def _get_hv_value_function(
505553
valfunc_cls: type[AcquisitionFunction] | None = None,
506554
valfunc_argfac: Callable[[Model], dict[str, Any]] | None = None,
507555
use_posterior_mean: bool = False,
556+
log: bool = False,
508557
) -> AcquisitionFunction:
509558
r"""Construct value function (i.e. inner acquisition function).
510559
This is a method for computing hypervolume.
@@ -518,20 +567,27 @@ def _get_hv_value_function(
518567
action="ignore",
519568
category=NumericsWarning,
520569
)
521-
base_value_function = qExpectedHypervolumeImprovement(
522-
model=model,
523-
ref_point=ref_point,
524-
partitioning=FastNondominatedPartitioning(
570+
571+
value_fn_kwargs = {
572+
"model": model,
573+
"ref_point": ref_point,
574+
"partitioning": FastNondominatedPartitioning(
525575
ref_point=ref_point,
526576
Y=torch.empty(
527577
(0, ref_point.shape[0]),
528578
dtype=ref_point.dtype,
529579
device=ref_point.device,
530580
),
531581
), # create empty partitioning
532-
sampler=sampler,
533-
objective=objective,
534-
)
582+
"sampler": sampler,
583+
"objective": objective,
584+
}
585+
586+
if log:
587+
base_value_function = qLogExpectedHypervolumeImprovement(**value_fn_kwargs)
588+
else:
589+
base_value_function = qExpectedHypervolumeImprovement(**value_fn_kwargs)
590+
535591
# ProjectedAcquisitionFunction requires this
536592
base_value_function.posterior_transform = None
537593

0 commit comments

Comments
 (0)