22
33from pytensor .link .jax .dispatch import jax_funcify
44from pytensor .tensor .blas import BatchedDot
5- from pytensor .tensor .math import Dot , MaxAndArgmax
5+ from pytensor .tensor .math import Argmax , Dot , Max
66from pytensor .tensor .nlinalg import (
77 SVD ,
88 Det ,
@@ -104,18 +104,73 @@ def batched_dot(a, b):
104104 return batched_dot
105105
106106
107- @jax_funcify .register (MaxAndArgmax )
108- def jax_funcify_MaxAndArgmax (op , ** kwargs ):
107+ # @jax_funcify.register(Max)
108+ # @jax_funcify.register(Argmax)
109+ # def jax_funcify_MaxAndArgmax(op, **kwargs):
110+ # axis = op.axis
111+
112+ # def maxandargmax(x, axis=axis):
113+ # if axis is None:
114+ # axes = tuple(range(x.ndim))
115+ # else:
116+ # axes = tuple(int(ax) for ax in axis)
117+
118+ # max_res = jnp.max(x, axis)
119+
120+ # # NumPy does not support multiple axes for argmax; this is a
121+ # # work-around
122+ # keep_axes = jnp.array(
123+ # [i for i in range(x.ndim) if i not in axes], dtype="int64"
124+ # )
125+ # # Not-reduced axes in front
126+ # transposed_x = jnp.transpose(
127+ # x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
128+ # )
129+ # kept_shape = transposed_x.shape[: len(keep_axes)]
130+ # reduced_shape = transposed_x.shape[len(keep_axes) :]
131+
132+ # # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
133+ # # Otherwise reshape would complain citing float arg
134+ # new_shape = (
135+ # *kept_shape,
136+ # jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
137+ # )
138+ # reshaped_x = transposed_x.reshape(new_shape)
139+
140+ # max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
141+
142+ # return max_res, max_idx_res
143+
144+ # return maxandargmax
145+
146+
147+ @jax_funcify .register (Max )
148+ def jax_funcify_Max (op , ** kwargs ):
109149 axis = op .axis
110150
111- def maxandargmax (x , axis = axis ):
151+ def max (x , axis = axis ):
152+ # if axis is None:
153+ # axes = tuple(range(x.ndim))
154+ # else:
155+ # axes = tuple(int(ax) for ax in axis)
156+
157+ max_res = jnp .max (x , axis )
158+
159+ return max_res
160+
161+ return max
162+
163+
164+ @jax_funcify .register (Argmax )
165+ def jax_funcify_Argmax (op , ** kwargs ):
166+ axis = op .axis
167+
168+ def argmax (x , axis = axis ):
112169 if axis is None :
113170 axes = tuple (range (x .ndim ))
114171 else :
115172 axes = tuple (int (ax ) for ax in axis )
116173
117- max_res = jnp .max (x , axis )
118-
119174 # NumPy does not support multiple axes for argmax; this is a
120175 # work-around
121176 keep_axes = jnp .array (
@@ -138,6 +193,6 @@ def maxandargmax(x, axis=axis):
138193
139194 max_idx_res = jnp .argmax (reshaped_x , axis = - 1 ).astype ("int64" )
140195
141- return max_res , max_idx_res
196+ return max_idx_res
142197
143- return maxandargmax
198+ return argmax
0 commit comments