@@ -155,12 +155,6 @@ def test_gpytorch_model(self):
155
155
# test noise shape validation
156
156
with self .assertRaises (BotorchTensorDimensionError ):
157
157
model .posterior (test_X , observation_noise = torch .rand (2 , ** tkwargs ))
158
- # test conditioning on observations
159
- cm = model .condition_on_observations (
160
- torch .rand (2 , 1 , ** tkwargs ), torch .rand (2 , 1 , ** tkwargs )
161
- )
162
- self .assertIsInstance (cm , SimpleGPyTorchModel )
163
- self .assertEqual (cm .train_targets .shape , torch .Size ([7 ]))
164
158
# test subset_output
165
159
with self .assertRaises (NotImplementedError ):
166
160
model .subset_output ([0 ])
@@ -255,20 +249,6 @@ def test_validate_tensor_args(self) -> None:
255
249
):
256
250
GPyTorchModel ._validate_tensor_args (X , Y , Yvar , strict = strict )
257
251
258
- def test_condition_on_observations_tensor_validation (self ) -> None :
259
- model = SimpleGPyTorchModel (torch .rand (5 , 1 ), torch .randn (5 , 1 ))
260
- model .posterior (torch .rand (2 , 1 )) # evaluate the model to form caches.
261
- # Outside of fantasize, the inputs are validated.
262
- with self .assertWarnsRegex (
263
- BotorchTensorDimensionWarning , "Non-strict enforcement of"
264
- ):
265
- model .condition_on_observations (torch .randn (2 , 1 ), torch .randn (5 , 2 , 1 ))
266
- # Inside of fantasize, the inputs are not validated.
267
- with fantasize (), warnings .catch_warnings (record = True ) as ws :
268
- warnings .filterwarnings ("always" , category = BotorchTensorDimensionWarning )
269
- model .condition_on_observations (torch .randn (2 , 1 ), torch .randn (5 , 2 , 1 ))
270
- self .assertFalse (any (w .category is BotorchTensorDimensionWarning for w in ws ))
271
-
272
252
def test_fantasize_flag (self ):
273
253
train_X = torch .rand (5 , 1 )
274
254
train_Y = torch .sin (train_X )
@@ -358,12 +338,6 @@ def test_batched_multi_output_gpytorch_model(self):
358
338
# test subset_output
359
339
with self .assertRaises (NotImplementedError ):
360
340
model .subset_output ([0 ])
361
- # test conditioning on observations
362
- cm = model .condition_on_observations (
363
- torch .rand (2 , 1 , ** tkwargs ), torch .rand (2 , 2 , ** tkwargs )
364
- )
365
- self .assertIsInstance (cm , SimpleBatchedMultiOutputGPyTorchModel )
366
- self .assertEqual (cm .train_targets .shape , torch .Size ([2 , 7 ]))
367
341
# test fantasize
368
342
sampler = SobolQMCNormalSampler (sample_shape = torch .Size ([2 ]))
369
343
cm = model .fantasize (torch .rand (2 , 1 , ** tkwargs ), sampler = sampler )
@@ -402,6 +376,56 @@ def test_batched_multi_output_gpytorch_model(self):
402
376
),
403
377
)
404
378
379
+ def test_condition_on_observations (self ):
380
+ for dtype , use_octf in itertools .product (
381
+ (torch .float , torch .double ), (False , True )
382
+ ):
383
+ tkwargs = {"device" : self .device , "dtype" : dtype }
384
+ octf = Standardize (m = 1 ) if use_octf else None
385
+ train_X = torch .rand (5 , 1 , ** tkwargs )
386
+ train_Y = torch .sin (train_X )
387
+ model = SimpleGPyTorchModel (train_X , train_Y , octf )
388
+
389
+ # must predict before conitioning on observations
390
+ model .posterior (torch .rand (2 , 1 , ** tkwargs ))
391
+
392
+ # test conditioning on observations
393
+ cm = model .condition_on_observations (
394
+ torch .rand (2 , 1 , ** tkwargs ), torch .rand (2 , 1 , ** tkwargs )
395
+ )
396
+ self .assertIsInstance (cm , SimpleGPyTorchModel )
397
+ self .assertEqual (cm .train_targets .shape , torch .Size ([7 ]))
398
+ model = SimpleGPyTorchModel (torch .rand (5 , 1 ), torch .randn (5 , 1 ))
399
+
400
+ model .posterior (torch .rand (2 , 1 )) # evaluate the model to form caches.
401
+ # Outside of fantasize, the inputs are validated.
402
+ with self .assertWarnsRegex (
403
+ BotorchTensorDimensionWarning , "Non-strict enforcement of"
404
+ ):
405
+ model .condition_on_observations (torch .randn (2 , 1 ), torch .randn (5 , 2 , 1 ))
406
+ # Inside of fantasize, the inputs are not validated.
407
+ with fantasize (), warnings .catch_warnings (record = True ) as ws :
408
+ warnings .filterwarnings ("always" , category = BotorchTensorDimensionWarning )
409
+ model .condition_on_observations (torch .randn (2 , 1 ), torch .randn (5 , 2 , 1 ))
410
+ self .assertFalse (any (w .category is BotorchTensorDimensionWarning for w in ws ))
411
+
412
+ def test_condition_on_observations_batched (self ):
413
+ for dtype in (torch .float , torch .double ):
414
+ tkwargs = {"device" : self .device , "dtype" : dtype }
415
+ train_X = torch .rand (5 , 1 , ** tkwargs )
416
+ train_Y = torch .cat ([torch .sin (train_X ), torch .cos (train_X )], dim = - 1 )
417
+ model = SimpleBatchedMultiOutputGPyTorchModel (train_X , train_Y )
418
+
419
+ # must predict before conitioning on observations
420
+ model .posterior (torch .rand (2 , 1 , ** tkwargs ))
421
+
422
+ # test conditioning on observations
423
+ cm = model .condition_on_observations (
424
+ torch .rand (2 , 1 , ** tkwargs ), torch .rand (2 , 2 , ** tkwargs )
425
+ )
426
+ self .assertIsInstance (cm , SimpleBatchedMultiOutputGPyTorchModel )
427
+ self .assertEqual (cm .train_targets .shape , torch .Size ([2 , 7 ]))
428
+
405
429
def test_posterior_transform (self ):
406
430
tkwargs = {"device" : self .device , "dtype" : torch .double }
407
431
train_X = torch .rand (5 , 2 , ** tkwargs )
@@ -564,11 +588,6 @@ def test_model_list_gpytorch_model(self):
564
588
)
565
589
self .assertIsInstance (posterior , GPyTorchPosterior )
566
590
self .assertEqual (posterior .mean .shape , torch .Size ([2 , 1 ]))
567
- # conditioning is not implemented (see ModelListGP for tests)
568
- with self .assertRaises (NotImplementedError ):
569
- model .condition_on_observations (
570
- X = torch .rand (2 , 1 , ** tkwargs ), Y = torch .rand (2 , 2 , ** tkwargs )
571
- )
572
591
573
592
def test_input_transform (self ):
574
593
# test that the input transforms are applied properly to individual models
@@ -651,3 +670,23 @@ def test_posterior_transform(self):
651
670
post_tf = ScalarizedPosteriorTransform (torch .ones (2 , ** tkwargs ))
652
671
post = model .posterior (torch .rand (3 , 1 , ** tkwargs ), posterior_transform = post_tf )
653
672
self .assertEqual (post .mean .shape , torch .Size ([3 , 1 ]))
673
+
674
+ def test_condition_on_observations_model_list (self ):
675
+ torch .manual_seed (12345 )
676
+ for dtype in (torch .float , torch .double ):
677
+ tkwargs = {"device" : self .device , "dtype" : dtype }
678
+ train_X1 , train_X2 = (
679
+ torch .rand (5 , 1 , ** tkwargs ),
680
+ torch .rand (5 , 1 , ** tkwargs ),
681
+ )
682
+ train_Y1 = torch .sin (train_X1 )
683
+ train_Y2 = torch .cos (train_X2 )
684
+ m1 = SimpleGPyTorchModel (train_X1 , train_Y1 )
685
+ m2 = SimpleGPyTorchModel (train_X2 , train_Y2 )
686
+ model = SimpleModelListGPyTorchModel (m1 , m2 )
687
+
688
+ # conditioning is not implemented (see ModelListGP for tests)
689
+ with self .assertRaises (NotImplementedError ):
690
+ model .condition_on_observations (
691
+ X = torch .rand (2 , 1 , ** tkwargs ), Y = torch .rand (2 , 2 , ** tkwargs )
692
+ )
0 commit comments