2727
2828import pymc as pm
2929
30+ from pymc .data import is_minibatch
3031from pymc .pytensorf import GeneratorOp , floatX
3132from pymc .tests .helpers import SeededTest , select_by_precision
3233
@@ -696,15 +697,10 @@ def test_common_errors(self):
696697
697698 def test_mixed1 (self ):
698699 with pm .Model ():
699- data = np .random .rand (10 , 20 , 30 , 40 , 50 )
700- mb = pm .Minibatch (data , [2 , None , 20 , Ellipsis , 10 ])
701- pm .Normal ("n" , observed = mb , total_size = (10 , None , 30 , Ellipsis , 50 ))
702-
703- def test_mixed2 (self ):
704- with pm .Model ():
705- data = np .random .rand (10 , 20 , 30 , 40 , 50 )
706- mb = pm .Minibatch (data , [2 , None , 20 ])
707- pm .Normal ("n" , observed = mb , total_size = (10 , None , 30 ))
700+ data = np .random .rand (10 , 20 )
701+ mb = pm .Minibatch (data , batch_size = 5 )
702+ v = pm .Normal ("n" , observed = mb , total_size = 10 )
703+ assert pm .logp (v , 1 ) is not None , "Check index is allowed in graph"
708704
709705 def test_free_rv (self ):
710706 with pm .Model () as model4 :
@@ -719,51 +715,28 @@ def test_free_rv(self):
719715
720716@pytest .mark .usefixtures ("strict_float32" )
721717class TestMinibatch :
722- data = np .random .rand (30 , 10 , 40 , 10 , 50 )
718+ data = np .random .rand (30 , 10 )
723719
724720 def test_1d (self ):
725- mb = pm .Minibatch (self .data , 20 )
726- assert mb .eval ().shape == (20 , 10 , 40 , 10 , 50 )
727-
728- def test_2d (self ):
729- mb = pm .Minibatch (self .data , [(10 , 42 ), (4 , 42 )])
730- assert mb .eval ().shape == (10 , 4 , 40 , 10 , 50 )
731-
732- @pytest .mark .parametrize (
733- "batch_size, expected" ,
734- [
735- ([(10 , 42 ), None , (4 , 42 )], (10 , 10 , 4 , 10 , 50 )),
736- ([(10 , 42 ), Ellipsis , (4 , 42 )], (10 , 10 , 40 , 10 , 4 )),
737- ([(10 , 42 ), None , Ellipsis , (4 , 42 )], (10 , 10 , 40 , 10 , 4 )),
738- ([10 , None , Ellipsis , (4 , 42 )], (10 , 10 , 40 , 10 , 4 )),
739- ],
740- )
741- def test_special_batch_size (self , batch_size , expected ):
742- mb = pm .Minibatch (self .data , batch_size )
743- assert mb .eval ().shape == expected
744-
745- def test_cloning_available (self ):
746- gop = pm .Minibatch (np .arange (100 ), 1 )
747- res = gop ** 2
748- shared = pytensor .shared (np .array ([10 ]))
749- res1 = pytensor .clone_replace (res , {gop : shared })
750- f = pytensor .function ([], res1 )
751- assert f () == np .array ([100 ])
752-
753- def test_align (self ):
754- m = pm .Minibatch (np .arange (1000 ), 1 , random_seed = 1 )
755- n = pm .Minibatch (np .arange (1000 ), 1 , random_seed = 1 )
756- f = pytensor .function ([], [m , n ])
757- n .eval () # not aligned
758- a , b = zip (* (f () for _ in range (1000 )))
759- assert a != b
760- pm .align_minibatches ()
761- a , b = zip (* (f () for _ in range (1000 )))
762- assert a == b
763- n .eval () # not aligned
764- pm .align_minibatches ([m ])
765- a , b = zip (* (f () for _ in range (1000 )))
766- assert a != b
767- pm .align_minibatches ([m , n ])
768- a , b = zip (* (f () for _ in range (1000 )))
769- assert a == b
721+ mb = pm .Minibatch (self .data , batch_size = 20 )
722+ assert is_minibatch (mb )
723+ assert mb .eval ().shape == (20 , 10 )
724+
725+ def test_allowed (self ):
726+ mb = pm .Minibatch (at .as_tensor (self .data ).astype (int ), batch_size = 20 )
727+ assert is_minibatch (mb )
728+
729+ def test_not_allowed (self ):
730+ with pytest .raises (ValueError , match = "not valid for Minibatch" ):
731+ mb = pm .Minibatch (at .as_tensor (self .data ) * 2 , batch_size = 20 )
732+
733+ def test_not_allowed2 (self ):
734+ with pytest .raises (ValueError , match = "not valid for Minibatch" ):
735+ mb = pm .Minibatch (self .data , at .as_tensor (self .data ) * 2 , batch_size = 20 )
736+
737+ def test_assert (self ):
738+ with pytest .raises (
739+ AssertionError , match = r"All variables shape\[0\] in Minibatch should be equal"
740+ ):
741+ d1 , d2 = pm .Minibatch (self .data , self .data [::2 ], batch_size = 20 )
742+ d1 .eval ()
0 commit comments