Skip to content

Commit f65300d

Browse files
Carl Hvarfnerfacebook-github-bot
authored andcommitted
Pulled out condition_on_observations-tests for easier testing (#2988)
Summary: Pulls out the existing condition_on_observations-tests for easier testing of the specific method. Does not need to land, but simplifies testing of condition_on_observations methods. Reviewed By: Balandat Differential Revision: D80805810
1 parent 9058a77 commit f65300d

File tree

4 files changed

+114
-66
lines changed

4 files changed

+114
-66
lines changed

test/models/test_fully_bayesian.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,7 @@ def test_construct_inputs(self) -> None:
792792
else:
793793
self.assertTrue(Yvar.equal(data_dict["train_Yvar"]))
794794

795-
def test_condition_on_observation(self) -> None:
795+
def test_fbstgp_condition_on_observations(self) -> None:
796796
# The following conditioned data shapes should work (output describes):
797797
# training data shape after cond(batch shape in output is req. in gpytorch)
798798
# X: num_models x n x d, Y: num_models x n x d --> num_models x n x d

test/models/test_fully_bayesian_multitask.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,7 @@ def test_acquisition_functions(self):
682682
)
683683
self.assertEqual(acqf(test_X).shape, torch.Size(batch_shape))
684684

685-
def test_condition_on_observation(self) -> None:
685+
def test_condition_on_observations(self) -> None:
686686
# The following conditioned data shapes should work (output describes):
687687
# training data shape after cond(batch shape in output is req. in gpytorch)
688688
# X: num_models x n x d, Y: num_models x n x d --> num_models x n x d

test/models/test_gpytorch.py

Lines changed: 70 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,6 @@ def test_gpytorch_model(self):
155155
# test noise shape validation
156156
with self.assertRaises(BotorchTensorDimensionError):
157157
model.posterior(test_X, observation_noise=torch.rand(2, **tkwargs))
158-
# test conditioning on observations
159-
cm = model.condition_on_observations(
160-
torch.rand(2, 1, **tkwargs), torch.rand(2, 1, **tkwargs)
161-
)
162-
self.assertIsInstance(cm, SimpleGPyTorchModel)
163-
self.assertEqual(cm.train_targets.shape, torch.Size([7]))
164158
# test subset_output
165159
with self.assertRaises(NotImplementedError):
166160
model.subset_output([0])
@@ -255,20 +249,6 @@ def test_validate_tensor_args(self) -> None:
255249
):
256250
GPyTorchModel._validate_tensor_args(X, Y, Yvar, strict=strict)
257251

258-
def test_condition_on_observations_tensor_validation(self) -> None:
259-
model = SimpleGPyTorchModel(torch.rand(5, 1), torch.randn(5, 1))
260-
model.posterior(torch.rand(2, 1)) # evaluate the model to form caches.
261-
# Outside of fantasize, the inputs are validated.
262-
with self.assertWarnsRegex(
263-
BotorchTensorDimensionWarning, "Non-strict enforcement of"
264-
):
265-
model.condition_on_observations(torch.randn(2, 1), torch.randn(5, 2, 1))
266-
# Inside of fantasize, the inputs are not validated.
267-
with fantasize(), warnings.catch_warnings(record=True) as ws:
268-
warnings.filterwarnings("always", category=BotorchTensorDimensionWarning)
269-
model.condition_on_observations(torch.randn(2, 1), torch.randn(5, 2, 1))
270-
self.assertFalse(any(w.category is BotorchTensorDimensionWarning for w in ws))
271-
272252
def test_fantasize_flag(self):
273253
train_X = torch.rand(5, 1)
274254
train_Y = torch.sin(train_X)
@@ -358,12 +338,6 @@ def test_batched_multi_output_gpytorch_model(self):
358338
# test subset_output
359339
with self.assertRaises(NotImplementedError):
360340
model.subset_output([0])
361-
# test conditioning on observations
362-
cm = model.condition_on_observations(
363-
torch.rand(2, 1, **tkwargs), torch.rand(2, 2, **tkwargs)
364-
)
365-
self.assertIsInstance(cm, SimpleBatchedMultiOutputGPyTorchModel)
366-
self.assertEqual(cm.train_targets.shape, torch.Size([2, 7]))
367341
# test fantasize
368342
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([2]))
369343
cm = model.fantasize(torch.rand(2, 1, **tkwargs), sampler=sampler)
@@ -402,6 +376,56 @@ def test_batched_multi_output_gpytorch_model(self):
402376
),
403377
)
404378

379+
def test_condition_on_observations(self):
380+
for dtype, use_octf in itertools.product(
381+
(torch.float, torch.double), (False, True)
382+
):
383+
tkwargs = {"device": self.device, "dtype": dtype}
384+
octf = Standardize(m=1) if use_octf else None
385+
train_X = torch.rand(5, 1, **tkwargs)
386+
train_Y = torch.sin(train_X)
387+
model = SimpleGPyTorchModel(train_X, train_Y, octf)
388+
389+
# must predict before conitioning on observations
390+
model.posterior(torch.rand(2, 1, **tkwargs))
391+
392+
# test conditioning on observations
393+
cm = model.condition_on_observations(
394+
torch.rand(2, 1, **tkwargs), torch.rand(2, 1, **tkwargs)
395+
)
396+
self.assertIsInstance(cm, SimpleGPyTorchModel)
397+
self.assertEqual(cm.train_targets.shape, torch.Size([7]))
398+
model = SimpleGPyTorchModel(torch.rand(5, 1), torch.randn(5, 1))
399+
400+
model.posterior(torch.rand(2, 1)) # evaluate the model to form caches.
401+
# Outside of fantasize, the inputs are validated.
402+
with self.assertWarnsRegex(
403+
BotorchTensorDimensionWarning, "Non-strict enforcement of"
404+
):
405+
model.condition_on_observations(torch.randn(2, 1), torch.randn(5, 2, 1))
406+
# Inside of fantasize, the inputs are not validated.
407+
with fantasize(), warnings.catch_warnings(record=True) as ws:
408+
warnings.filterwarnings("always", category=BotorchTensorDimensionWarning)
409+
model.condition_on_observations(torch.randn(2, 1), torch.randn(5, 2, 1))
410+
self.assertFalse(any(w.category is BotorchTensorDimensionWarning for w in ws))
411+
412+
def test_condition_on_observations_batched(self):
413+
for dtype in (torch.float, torch.double):
414+
tkwargs = {"device": self.device, "dtype": dtype}
415+
train_X = torch.rand(5, 1, **tkwargs)
416+
train_Y = torch.cat([torch.sin(train_X), torch.cos(train_X)], dim=-1)
417+
model = SimpleBatchedMultiOutputGPyTorchModel(train_X, train_Y)
418+
419+
# must predict before conitioning on observations
420+
model.posterior(torch.rand(2, 1, **tkwargs))
421+
422+
# test conditioning on observations
423+
cm = model.condition_on_observations(
424+
torch.rand(2, 1, **tkwargs), torch.rand(2, 2, **tkwargs)
425+
)
426+
self.assertIsInstance(cm, SimpleBatchedMultiOutputGPyTorchModel)
427+
self.assertEqual(cm.train_targets.shape, torch.Size([2, 7]))
428+
405429
def test_posterior_transform(self):
406430
tkwargs = {"device": self.device, "dtype": torch.double}
407431
train_X = torch.rand(5, 2, **tkwargs)
@@ -564,11 +588,6 @@ def test_model_list_gpytorch_model(self):
564588
)
565589
self.assertIsInstance(posterior, GPyTorchPosterior)
566590
self.assertEqual(posterior.mean.shape, torch.Size([2, 1]))
567-
# conditioning is not implemented (see ModelListGP for tests)
568-
with self.assertRaises(NotImplementedError):
569-
model.condition_on_observations(
570-
X=torch.rand(2, 1, **tkwargs), Y=torch.rand(2, 2, **tkwargs)
571-
)
572591

573592
def test_input_transform(self):
574593
# test that the input transforms are applied properly to individual models
@@ -651,3 +670,23 @@ def test_posterior_transform(self):
651670
post_tf = ScalarizedPosteriorTransform(torch.ones(2, **tkwargs))
652671
post = model.posterior(torch.rand(3, 1, **tkwargs), posterior_transform=post_tf)
653672
self.assertEqual(post.mean.shape, torch.Size([3, 1]))
673+
674+
def test_condition_on_observations_model_list(self):
675+
torch.manual_seed(12345)
676+
for dtype in (torch.float, torch.double):
677+
tkwargs = {"device": self.device, "dtype": dtype}
678+
train_X1, train_X2 = (
679+
torch.rand(5, 1, **tkwargs),
680+
torch.rand(5, 1, **tkwargs),
681+
)
682+
train_Y1 = torch.sin(train_X1)
683+
train_Y2 = torch.cos(train_X2)
684+
m1 = SimpleGPyTorchModel(train_X1, train_Y1)
685+
m2 = SimpleGPyTorchModel(train_X2, train_Y2)
686+
model = SimpleModelListGPyTorchModel(m1, m2)
687+
688+
# conditioning is not implemented (see ModelListGP for tests)
689+
with self.assertRaises(NotImplementedError):
690+
model.condition_on_observations(
691+
X=torch.rand(2, 1, **tkwargs), Y=torch.rand(2, 2, **tkwargs)
692+
)

test/models/test_model_list_gp_regression.py

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -170,39 +170,6 @@ def _base_test_ModelListGP(
170170
if gpytorch_posterior_expected:
171171
self.assertIsInstance(posterior.distribution, MultivariateNormal)
172172

173-
# test condition_on_observations
174-
f_x = [torch.rand(2, 1, **tkwargs) for _ in range(2)]
175-
f_y = torch.rand(2, 2, **tkwargs)
176-
if fixed_noise:
177-
noise = 0.1 + 0.1 * torch.rand_like(f_y)
178-
cond_kwargs = {"noise": noise}
179-
else:
180-
cond_kwargs = {}
181-
cm = model.condition_on_observations(f_x, f_y, **cond_kwargs)
182-
self.assertIsInstance(cm, ModelListGP)
183-
184-
# test condition_on_observations batched
185-
f_x = [torch.rand(3, 2, 1, **tkwargs) for _ in range(2)]
186-
f_y = torch.rand(3, 2, 2, **tkwargs)
187-
cm = model.condition_on_observations(f_x, f_y, **cond_kwargs)
188-
self.assertIsInstance(cm, ModelListGP)
189-
190-
# test condition_on_observations batched (fast fantasies)
191-
f_x = [torch.rand(2, 1, **tkwargs) for _ in range(2)]
192-
f_y = torch.rand(3, 2, 2, **tkwargs)
193-
cm = model.condition_on_observations(f_x, f_y, **cond_kwargs)
194-
self.assertIsInstance(cm, ModelListGP)
195-
196-
# test condition_on_observations (incorrect input shape error)
197-
with self.assertRaises(BotorchTensorDimensionError):
198-
model.condition_on_observations(
199-
f_x, torch.rand(3, 2, 3, **tkwargs), **cond_kwargs
200-
)
201-
202-
# test X having wrong size
203-
with self.assertRaises(BotorchTensorDimensionError):
204-
model.condition_on_observations(f_x[:1], f_y)
205-
206173
# test posterior transform
207174
X = torch.rand(3, 1, **tkwargs)
208175
weights = torch.tensor([1, 2], **tkwargs)
@@ -218,6 +185,48 @@ def _base_test_ModelListGP(
218185

219186
return model
220187

188+
def test_condition_on_observations(self) -> None:
189+
for dtype, outcome_transform in itertools.product(
190+
(torch.float, torch.double), ("None", "Standardize", "Log", "Chained")
191+
):
192+
with self.subTest(dtype=dtype, outcome_transform=outcome_transform):
193+
tkwargs = {"device": self.device, "dtype": dtype}
194+
model = _get_model(
195+
fixed_noise=False, outcome_transform=outcome_transform, **tkwargs
196+
)
197+
# need to predict before conditioning
198+
test_x = torch.tensor([[0.25], [0.75]], **tkwargs)
199+
_ = model.posterior(test_x)
200+
201+
# test condition_on_observations
202+
f_x = [torch.rand(2, 1, **tkwargs) for _ in range(2)]
203+
f_y = torch.rand(2, 2, **tkwargs)
204+
cond_kwargs = {}
205+
cm = model.condition_on_observations(f_x, f_y, **cond_kwargs)
206+
self.assertIsInstance(cm, ModelListGP)
207+
208+
# test condition_on_observations batched
209+
f_x = [torch.rand(3, 2, 1, **tkwargs) for _ in range(2)]
210+
f_y = torch.rand(3, 2, 2, **tkwargs)
211+
cm = model.condition_on_observations(f_x, f_y, **cond_kwargs)
212+
self.assertIsInstance(cm, ModelListGP)
213+
214+
# test condition_on_observations batched (fast fantasies)
215+
f_x = [torch.rand(2, 1, **tkwargs) for _ in range(2)]
216+
f_y = torch.rand(3, 2, 2, **tkwargs)
217+
cm = model.condition_on_observations(f_x, f_y, **cond_kwargs)
218+
self.assertIsInstance(cm, ModelListGP)
219+
220+
# test condition_on_observations (incorrect input shape error)
221+
with self.assertRaises(BotorchTensorDimensionError):
222+
model.condition_on_observations(
223+
f_x, torch.rand(3, 2, 3, **tkwargs), **cond_kwargs
224+
)
225+
226+
# test X having wrong size
227+
with self.assertRaises(BotorchTensorDimensionError):
228+
model.condition_on_observations(f_x[:1], f_y)
229+
221230
def test_ModelListGP(self) -> None:
222231
for dtype, outcome_transform in itertools.product(
223232
(torch.float, torch.double), ("None", "Standardize", "Log", "Chained")

0 commit comments

Comments
 (0)