Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion botorch/models/gpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,8 @@ def condition_on_observations(
>>> new_Y = torch.sin(new_X[:, :1]) + torch.cos(new_X[:, 1:])
>>> model = model.condition_on_observations(X=new_X, Y=new_Y)
"""
X = self.transform_inputs(X)
Yvar = noise

if hasattr(self, "outcome_transform"):
# pass the transformed data to get_fantasy_model below
# (unless we've already trasnformed if BatchedMultiOutputGPyTorchModel)
Expand Down
6 changes: 2 additions & 4 deletions botorch/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def fantasize(
if observation_noise is not None:
kwargs["noise"] = observation_noise.expand(Y.shape[1:])
return self.condition_on_observations(
X=self.transform_inputs(X),
X=X,
Y=Y,
**kwargs,
)
Expand All @@ -395,9 +395,7 @@ def fantasize(
Y_fantasized = sampler(post_X) # num_fantasies x batch_shape x n' x m
if observation_noise is not None:
kwargs["noise"] = observation_noise.expand(Y_fantasized.shape[1:])
return self.condition_on_observations(
X=self.transform_inputs(X), Y=Y_fantasized, **kwargs
)
return self.condition_on_observations(X=X, Y=Y_fantasized, **kwargs)


class ModelList(Model):
Expand Down
44 changes: 44 additions & 0 deletions test/models/test_gpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,3 +690,47 @@ def test_condition_on_observations_model_list(self):
model.condition_on_observations(
X=torch.rand(2, 1, **tkwargs), Y=torch.rand(2, 2, **tkwargs)
)

def test_condition_on_observations_input_transform_consistency(self):
"""Test that input transforms are applied consistently in
condition_on_observations.

This addresses https://github.com/pytorch/botorch/issues/2533:
inputs should be transformed when conditioning.
"""
for dtype in (torch.float, torch.double):
tkwargs = {"device": self.device, "dtype": dtype}

# Create model with input transform
train_X = torch.tensor([[0.0], [0.5], [1.0]], **tkwargs)
train_Y = torch.tensor([[1.0], [2.0], [3.0]], **tkwargs)

input_transform = SimpleInputTransform(transform_on_train=True)
model = SimpleGPyTorchModel(
train_X, train_Y, input_transform=input_transform
)

# Condition on new observations
new_X = torch.tensor([[0.25], [0.75]], **tkwargs)
new_Y = torch.tensor([[1.5], [2.5]], **tkwargs)

# Get original train_inputs for comparison
model.eval() # Put in eval mode to see transformed inputs
original_transformed_inputs = model.train_inputs[0].clone()
_ = model.posterior(train_X)
# Condition on observations
conditioned_model = model.condition_on_observations(new_X, new_Y)
conditioned_model.eval() # Put in eval mode to see transformed inputs

# Check that new inputs were transformed before being added
expected_transformed_new_X = input_transform(new_X)
expected_combined_inputs = torch.cat(
[original_transformed_inputs, expected_transformed_new_X], dim=0
)

# NOTE This would not have passed before - the last two inputs
# (corresponding to new_X) would not have been transformed.
self.assertAllClose(
conditioned_model.train_inputs[0],
expected_combined_inputs,
)
Loading