
Commit bdfc0c5

Carl Hvarfner authored and facebook-github-bot committed
condition_on_observations does not apply input transforms (#2989)
Summary:
Pull Request resolved: #2989

Fixes OSS-reported issues in the interaction between condition_on_observations and input transforms, as input transforms are generally bypassed when using the method.

Why bother? I personally seem to use this method for just about all my research, including the initialization work currently under review. Moreover, it can't be trialed in Ax without this fix.

Differential Revision: D80813693
1 parent b8bebe3 commit bdfc0c5
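
For context, a minimal sketch of the scenario this commit fixes, assuming a SingleTaskGP with a Normalize input transform; the model choice, toy data, and variable names are illustrative and not taken from the commit.

# Minimal sketch of the reported scenario; SingleTaskGP, Normalize, and the toy
# data below are assumptions for illustration only.
import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms.input import Normalize

train_X = 10.0 * torch.rand(8, 2, dtype=torch.double)  # raw (unnormalized) inputs
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y, input_transform=Normalize(d=2))

# GPyTorch's get_fantasy_model needs prediction caches, so evaluate the model once.
_ = model.posterior(train_X)

new_X = 10.0 * torch.rand(2, 2, dtype=torch.double)  # same raw space as train_X
new_Y = new_X.sum(dim=-1, keepdim=True)

# Before this fix, new_X was appended to the stored (transformed) training inputs
# without being transformed itself; with the fix, condition_on_observations runs
# new_X through the model's input transform first.
conditioned_model = model.condition_on_observations(X=new_X, Y=new_Y)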

File tree: 3 files changed (+47 −5 lines)

botorch/models/gpytorch.py

Lines changed: 1 addition & 1 deletion
@@ -242,8 +242,8 @@ def condition_on_observations(
             >>> new_Y = torch.sin(new_X[:, :1]) + torch.cos(new_X[:, 1:])
             >>> model = model.condition_on_observations(X=new_X, Y=new_Y)
         """
+        X = self.transform_inputs(X)
         Yvar = noise
-
         if hasattr(self, "outcome_transform"):
             # pass the transformed data to get_fantasy_model below
             # (unless we've already trasnformed if BatchedMultiOutputGPyTorchModel)
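
The helper used by the added line is Model.transform_inputs, which applies the model's input_transform (if one is set) to a raw tensor. A small hedged illustration follows; the bounds and toy data are made up for readability.

# Hedged illustration of transform_inputs; the bounds and data are assumptions.
import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms.input import Normalize

bounds = torch.tensor([[0.0, 0.0], [10.0, 4.0]], dtype=torch.double)
train_X = torch.tensor([[1.0, 1.0], [9.0, 3.0]], dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y, input_transform=Normalize(d=2, bounds=bounds))

raw_X = torch.tensor([[5.0, 2.0]], dtype=torch.double)
# transform_inputs maps raw inputs into the transform's space; here
# [[5.0, 2.0]] becomes [[0.5, 0.5]] in the unit cube defined by `bounds`,
# i.e. the same space as the transformed train_inputs used for conditioning.
print(model.transform_inputs(raw_X))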

botorch/models/model.py

Lines changed: 2 additions & 4 deletions
@@ -379,7 +379,7 @@ def fantasize(
         if observation_noise is not None:
             kwargs["noise"] = observation_noise.expand(Y.shape[1:])
         return self.condition_on_observations(
-            X=self.transform_inputs(X),
+            X=X,
             Y=Y,
             **kwargs,
         )
@@ -395,9 +395,7 @@ def fantasize(
         Y_fantasized = sampler(post_X)  # num_fantasies x batch_shape x n' x m
         if observation_noise is not None:
             kwargs["noise"] = observation_noise.expand(Y_fantasized.shape[1:])
-        return self.condition_on_observations(
-            X=self.transform_inputs(X), Y=Y_fantasized, **kwargs
-        )
+        return self.condition_on_observations(X=X, Y=Y_fantasized, **kwargs)


 class ModelList(Model):
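
With the transform now applied inside condition_on_observations, fantasize forwards the raw X unchanged and the inputs are transformed exactly once downstream. A hedged usage sketch follows; the sampler size, data, and Normalize choice are assumptions, not taken from the commit.

# Hedged fantasize sketch; data, sampler size, and Normalize are assumptions.
import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms.input import Normalize
from botorch.sampling.normal import SobolQMCNormalSampler

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y, input_transform=Normalize(d=2))

sampler = SobolQMCNormalSampler(sample_shape=torch.Size([4]))
# fantasize evaluates the posterior at X, samples fantasy outcomes, and then calls
# condition_on_observations; after this change the raw X is passed straight
# through and the input transform is applied once in condition_on_observations.
fantasy_model = model.fantasize(X=torch.rand(3, 2, dtype=torch.double), sampler=sampler)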

test/models/test_gpytorch.py

Lines changed: 44 additions & 0 deletions
@@ -690,3 +690,47 @@ def test_condition_on_observations_model_list(self):
             model.condition_on_observations(
                 X=torch.rand(2, 1, **tkwargs), Y=torch.rand(2, 2, **tkwargs)
             )
+
+    def test_condition_on_observations_input_transform_consistency(self):
+        """Test that input transforms are applied consistently in
+        condition_on_observations.
+
+        This addresses https://github.com/pytorch/botorch/issues/2533:
+        inputs should be transformed when conditioning.
+        """
+        for dtype in (torch.float, torch.double):
+            tkwargs = {"device": self.device, "dtype": dtype}
+
+            # Create model with input transform
+            train_X = torch.tensor([[0.0], [0.5], [1.0]], **tkwargs)
+            train_Y = torch.tensor([[1.0], [2.0], [3.0]], **tkwargs)
+
+            input_transform = SimpleInputTransform(transform_on_train=True)
+            model = SimpleGPyTorchModel(
+                train_X, train_Y, input_transform=input_transform
+            )
+
+            # Condition on new observations
+            new_X = torch.tensor([[0.25], [0.75]], **tkwargs)
+            new_Y = torch.tensor([[1.5], [2.5]], **tkwargs)
+
+            # Get original train_inputs for comparison
+            model.eval()  # Put in eval mode to see transformed inputs
+            original_transformed_inputs = model.train_inputs[0].clone()
+            _ = model.posterior(train_X)
+            # Condition on observations
+            conditioned_model = model.condition_on_observations(new_X, new_Y)
+            conditioned_model.eval()  # Put in eval mode to see transformed inputs
+
+            # Check that new inputs were transformed before being added
+            expected_transformed_new_X = input_transform(new_X)
+            expected_combined_inputs = torch.cat(
+                [original_transformed_inputs, expected_transformed_new_X], dim=0
+            )
+
+            # NOTE This would not have passed before - the last two inputs
+            # (corresponding to new_X) would not have been transformed.
+            self.assertAllClose(
+                conditioned_model.train_inputs[0],
+                expected_combined_inputs,
+            )
