1313# limitations under the License.
1414
1515
16+ from typing import Union
17+
1618import aesara
1719import aesara .tensor as at
1820import numpy as np
@@ -139,10 +141,18 @@ def test_simplex_accuracy():
139141
140142
141143def test_sum_to_1 ():
142- check_vector_transform (tr .sum_to_1 , Simplex (2 ))
143- check_vector_transform (tr .sum_to_1 , Simplex (4 ))
144+ check_vector_transform (tr .univariate_sum_to_1 , Simplex (2 ))
145+ check_vector_transform (tr .univariate_sum_to_1 , Simplex (4 ))
144146
145- check_jacobian_det (tr .sum_to_1 , Vector (Unit , 2 ), at .dvector , np .array ([0 , 0 ]), lambda x : x [:- 1 ])
147+ with pytest .raises (ValueError , match = r"\(ndim_supp\) must not exceed 1" ):
148+ tr .SumTo1 (2 )
149+
150+ check_jacobian_det (
151+ tr .univariate_sum_to_1 , Vector (Unit , 2 ), at .dvector , np .array ([0 , 0 ]), lambda x : x [:- 1 ]
152+ )
153+ check_jacobian_det (
154+ tr .multivariate_sum_to_1 , Vector (Unit , 2 ), at .dvector , np .array ([0 , 0 ]), lambda x : x [:- 1 ]
155+ )
146156
147157
148158def test_log ():
@@ -241,28 +251,36 @@ def test_circular():
241251
242252
243253def test_ordered ():
244- check_vector_transform (tr .ordered , SortedVector (6 ))
254+ check_vector_transform (tr .univariate_ordered , SortedVector (6 ))
255+
256+ with pytest .raises (ValueError , match = r"\(ndim_supp\) must not exceed 1" ):
257+ tr .Ordered (2 )
245258
246- check_jacobian_det (tr .ordered , Vector (R , 2 ), at .dvector , np .array ([0 , 0 ]), elemwise = False )
259+ check_jacobian_det (
260+ tr .univariate_ordered , Vector (R , 2 ), at .dvector , np .array ([0 , 0 ]), elemwise = False
261+ )
262+ check_jacobian_det (
263+ tr .multivariate_ordered , Vector (R , 2 ), at .dvector , np .array ([0 , 0 ]), elemwise = False
264+ )
247265
248- vals = get_values (tr .ordered , Vector (R , 3 ), at .dvector , np .zeros (3 ))
266+ vals = get_values (tr .univariate_ordered , Vector (R , 3 ), at .dvector , np .zeros (3 ))
249267 close_to_logical (np .diff (vals ) >= 0 , True , tol )
250268
251269
252270def test_chain_values ():
253- chain_tranf = tr .Chain ([tr .logodds , tr .ordered ])
271+ chain_tranf = tr .Chain ([tr .logodds , tr .univariate_ordered ])
254272 vals = get_values (chain_tranf , Vector (R , 5 ), at .dvector , np .zeros (5 ))
255273 close_to_logical (np .diff (vals ) >= 0 , True , tol )
256274
257275
258276def test_chain_vector_transform ():
259- chain_tranf = tr .Chain ([tr .logodds , tr .ordered ])
277+ chain_tranf = tr .Chain ([tr .logodds , tr .univariate_ordered ])
260278 check_vector_transform (chain_tranf , UnitSortedVector (3 ))
261279
262280
263281@pytest .mark .xfail (reason = "Fails due to precision issue. Values just close to expected." )
264282def test_chain_jacob_det ():
265- chain_tranf = tr .Chain ([tr .logodds , tr .ordered ])
283+ chain_tranf = tr .Chain ([tr .logodds , tr .univariate_ordered ])
266284 check_jacobian_det (chain_tranf , Vector (R , 4 ), at .dvector , np .zeros (4 ), elemwise = False )
267285
268286
@@ -327,7 +345,14 @@ def check_vectortransform_elementwise_logp(self, model):
327345 jacob_det = transform .log_jac_det (test_array_transf , * x .owner .inputs )
328346 # Original distribution is univariate
329347 if x .owner .op .ndim_supp == 0 :
330- assert model .logp (x , sum = False )[0 ].ndim == x .ndim == (jacob_det .ndim + 1 )
348+ tr_steps = getattr (transform , "transform_list" , [transform ])
349+ transform_keeps_dim = any (
350+ [isinstance (ts , Union [tr .SumTo1 , tr .Ordered ]) for ts in tr_steps ]
351+ )
352+ if transform_keeps_dim :
353+ assert model .logp (x , sum = False )[0 ].ndim == x .ndim == jacob_det .ndim
354+ else :
355+ assert model .logp (x , sum = False )[0 ].ndim == x .ndim == (jacob_det .ndim + 1 )
331356 # Original distribution is multivariate
332357 else :
333358 assert model .logp (x , sum = False )[0 ].ndim == (x .ndim - 1 ) == jacob_det .ndim
@@ -449,7 +474,7 @@ def test_normal_ordered(self):
449474 {"mu" : 0.0 , "sigma" : 1.0 },
450475 size = 3 ,
451476 initval = np .asarray ([- 1.0 , 1.0 , 4.0 ]),
452- transform = tr .ordered ,
477+ transform = tr .univariate_ordered ,
453478 )
454479 self .check_vectortransform_elementwise_logp (model )
455480
@@ -467,7 +492,7 @@ def test_half_normal_ordered(self, sigma, size):
467492 {"sigma" : sigma },
468493 size = size ,
469494 initval = initval ,
470- transform = tr .Chain ([tr .log , tr .ordered ]),
495+ transform = tr .Chain ([tr .log , tr .univariate_ordered ]),
471496 )
472497 self .check_vectortransform_elementwise_logp (model )
473498
@@ -479,7 +504,7 @@ def test_exponential_ordered(self, lam, size):
479504 {"lam" : lam },
480505 size = size ,
481506 initval = initval ,
482- transform = tr .Chain ([tr .log , tr .ordered ]),
507+ transform = tr .Chain ([tr .log , tr .univariate_ordered ]),
483508 )
484509 self .check_vectortransform_elementwise_logp (model )
485510
@@ -501,7 +526,7 @@ def test_beta_ordered(self, a, b, size):
501526 {"alpha" : a , "beta" : b },
502527 size = size ,
503528 initval = initval ,
504- transform = tr .Chain ([tr .logodds , tr .ordered ]),
529+ transform = tr .Chain ([tr .logodds , tr .univariate_ordered ]),
505530 )
506531 self .check_vectortransform_elementwise_logp (model )
507532
@@ -524,7 +549,7 @@ def transform_params(*inputs):
524549 {"lower" : lower , "upper" : upper },
525550 size = size ,
526551 initval = initval ,
527- transform = tr .Chain ([interval , tr .ordered ]),
552+ transform = tr .Chain ([interval , tr .univariate_ordered ]),
528553 )
529554 self .check_vectortransform_elementwise_logp (model )
530555
@@ -536,7 +561,7 @@ def test_vonmises_ordered(self, mu, kappa, size):
536561 {"mu" : mu , "kappa" : kappa },
537562 size = size ,
538563 initval = initval ,
539- transform = tr .Chain ([tr .circular , tr .ordered ]),
564+ transform = tr .Chain ([tr .circular , tr .univariate_ordered ]),
540565 )
541566 self .check_vectortransform_elementwise_logp (model )
542567
@@ -545,7 +570,7 @@ def test_vonmises_ordered(self, mu, kappa, size):
545570 [
546571 (0.0 , 1.0 , (2 ,), tr .simplex ),
547572 (0.5 , 5.5 , (2 , 3 ), tr .simplex ),
548- (np .zeros (3 ), np .ones (3 ), (4 , 3 ), tr .Chain ([tr .sum_to_1 , tr .logodds ])),
573+ (np .zeros (3 ), np .ones (3 ), (4 , 3 ), tr .Chain ([tr .univariate_sum_to_1 , tr .logodds ])),
549574 ],
550575 )
551576 def test_uniform_other (self , lower , upper , size , transform ):
@@ -569,7 +594,11 @@ def test_uniform_other(self, lower, upper, size, transform):
569594 def test_mvnormal_ordered (self , mu , cov , size , shape ):
570595 initval = np .sort (np .random .randn (* shape ))
571596 model = self .build_model (
572- pm .MvNormal , {"mu" : mu , "cov" : cov }, size = size , initval = initval , transform = tr .ordered
597+ pm .MvNormal ,
598+ {"mu" : mu , "cov" : cov },
599+ size = size ,
600+ initval = initval ,
601+ transform = tr .multivariate_ordered ,
573602 )
574603 self .check_vectortransform_elementwise_logp (model )
575604
@@ -598,3 +627,95 @@ def test_discrete_trafo():
598627 with pytest .raises (ValueError ) as err :
599628 pm .Binomial ("a" , n = 5 , p = 0.5 , transform = "log" )
600629 err .match ("Transformations for discrete distributions" )
630+
631+
632+ def test_2d_univariate_ordered ():
633+ with pm .Model () as model :
634+ x_1d = pm .Normal (
635+ "x_1d" ,
636+ mu = [- 3 , - 1 , 1 , 2 ],
637+ sigma = 1 ,
638+ size = (4 ,),
639+ transform = tr .univariate_ordered ,
640+ )
641+ x_2d = pm .Normal (
642+ "x_2d" ,
643+ mu = [- 3 , - 1 , 1 , 2 ],
644+ sigma = 1 ,
645+ size = (10 , 4 ),
646+ transform = tr .univariate_ordered ,
647+ )
648+
649+ log_p = model .compile_logp (sum = False )(
650+ {"x_1d_ordered__" : np .zeros ((4 ,)), "x_2d_ordered__" : np .zeros ((10 , 4 ))}
651+ )
652+ np .testing .assert_allclose (np .tile (log_p [0 ], (10 , 1 )), log_p [1 ])
653+
654+
655+ def test_2d_multivariate_ordered ():
656+ with pm .Model () as model :
657+ x_1d = pm .MvNormal (
658+ "x_1d" ,
659+ mu = [- 1 , 1 ],
660+ cov = np .eye (2 ),
661+ initval = [- 1 , 1 ],
662+ transform = tr .multivariate_ordered ,
663+ )
664+ x_2d = pm .MvNormal (
665+ "x_2d" ,
666+ mu = [- 1 , 1 ],
667+ cov = np .eye (2 ),
668+ size = 2 ,
669+ initval = [[- 1 , 1 ], [- 1 , 1 ]],
670+ transform = tr .multivariate_ordered ,
671+ )
672+
673+ log_p = model .compile_logp (sum = False )(
674+ {"x_1d_ordered__" : np .zeros ((2 ,)), "x_2d_ordered__" : np .zeros ((2 , 2 ))}
675+ )
676+ np .testing .assert_allclose (log_p [0 ], log_p [1 ])
677+
678+
679+ def test_2d_univariate_sum_to_1 ():
680+ with pm .Model () as model :
681+ x_1d = pm .Normal (
682+ "x_1d" ,
683+ mu = [- 3 , - 1 , 1 , 2 ],
684+ sigma = 1 ,
685+ size = (4 ,),
686+ transform = tr .univariate_sum_to_1 ,
687+ )
688+ x_2d = pm .Normal (
689+ "x_2d" ,
690+ mu = [- 3 , - 1 , 1 , 2 ],
691+ sigma = 1 ,
692+ size = (10 , 4 ),
693+ transform = tr .univariate_sum_to_1 ,
694+ )
695+
696+ log_p = model .compile_logp (sum = False )(
697+ {"x_1d_sumto1__" : np .zeros (3 ), "x_2d_sumto1__" : np .zeros ((10 , 3 ))}
698+ )
699+ np .testing .assert_allclose (np .tile (log_p [0 ], (10 , 1 )), log_p [1 ])
700+
701+
702+ def test_2d_multivariate_sum_to_1 ():
703+ with pm .Model () as model :
704+ x_1d = pm .MvNormal (
705+ "x_1d" ,
706+ mu = [- 1 , 1 ],
707+ cov = np .eye (2 ),
708+ transform = tr .multivariate_sum_to_1 ,
709+ )
710+ x_2d = pm .MvNormal (
711+ "x_2d" ,
712+ mu = [- 1 , 1 ],
713+ cov = np .eye (2 ),
714+ size = 2 ,
715+ transform = tr .multivariate_sum_to_1 ,
716+ )
717+
718+ log_p = model .compile_logp (sum = False )(
719+ {"x_1d_sumto1__" : np .zeros (1 ), "x_2d_sumto1__" : np .zeros ((2 , 1 ))}
720+ )
721+ np .testing .assert_allclose (log_p [0 ], log_p [1 ])
0 commit comments