@@ -764,10 +764,10 @@ def setup_method(self):
764764 Max .debug = 0
765765 Argmax .debug = 0
766766
767- def test_basic (self ):
767+ def test_basic_0 (self ):
768768 # dbt: for some reason, Argmax does not work when I pass: n = as_tensor_variable(5.0)
769- n = as_tensor_variable (5.0 )
770- v , i = eval_outputs (max_and_argmax (n ))
769+ n = as_tensor_variable (5 )
770+ v , i = eval_outputs (max_and_argmax (n , axis = () ))
771771 assert v == 5.0
772772 assert i == 0
773773 assert i .dtype == "int64"
@@ -809,11 +809,7 @@ def test_basic_2(self, axis, np_axis):
809809 v_shape , i_shape = eval_outputs ([vt .shape , it .shape ])
810810 assert tuple (v_shape ) == vt .type .shape
811811 assert tuple (i_shape ) == it .type .shape
812- # Test values
813- v , i = eval_outputs ([vt , it ])
814- assert i .dtype == "int64"
815- assert np .all (v == np_max )
816- assert np .all (i == np_argm )
812+ # Test valuesgi
817813
818814 @pytest .mark .parametrize (
819815 "axis,np_axis" ,
@@ -1372,27 +1368,30 @@ def _grad_list(self):
13721368 data = random (2 , 3 )
13731369 for fct in [max_and_argmax , max , min ]:
13741370 utt .verify_grad (lambda v : fct (v , axis = [0 , 1 ]), [data ])
1375- # n = as_tensor_variable(data)
1376- # check_grad_max(data, eval_outputs(grad(max_and_argmax(n,
1377- # axis=1)[0], n)),axis=1)
1371+ n = as_tensor_variable (data )
1372+ check_grad_max (
1373+ data , eval_outputs (grad (max_and_argmax (n , axis = 1 )[0 ], n )), axis = 1
1374+ )
13781375
13791376 def test_uint (self ):
13801377 for dtype in ("uint8" , "uint16" , "uint32" , "uint64" ):
13811378 itype = np .iinfo (dtype )
13821379 data = np .array ([itype .min + 3 , itype .min , itype .max - 5 , itype .max ], dtype )
13831380 n = as_tensor_variable (data )
13841381 assert min (n ).dtype == dtype
1382+ # print(min(n).owner.inputs[1].acc_dtype)
13851383 i = eval_outputs (min (n ))
13861384 # pytensor.dprint(n)
1387- for x in n :
1388- print (x .eval ())
1385+ # for x in n:
1386+ # print(x.eval())
13891387 print (i )
13901388 print (itype .min )
13911389 print ()
13921390 assert i == itype .min
1393- assert max (n ).dtype == dtype
1394- i = eval_outputs (max (n ))
1395- assert i == itype .max
1391+ # assert max(n).dtype == dtype
1392+ # i = eval_outputs(max(n))
1393+ # assert i == itype.max
1394+ # assert 0
13961395
13971396 def test_bool (self ):
13981397 data = np .array ([True , False ], "bool" )
0 commit comments