@@ -343,6 +345,208 @@ def grad(self, inp, grads):
         return (g_x,)


+class TensorMax(COp):
+    """
+    Calculate the max over a given axis or over all axes.
+
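+    A minimal usage sketch (illustrative only; it assumes the public
+    `pytensor.tensor.max` helper routes through this Op):
+
+    >>> import numpy as np
+    >>> import pytensor.tensor as pt
+    >>> x = pt.matrix("x")
+    >>> pt.max(x, axis=0).eval({x: np.arange(6.0).reshape(2, 3)})
+    array([3., 4., 5.])
+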
352+ """
+
+    nin = 1  # the input tensor; axis lives in params, not in the inputs
+    nout = 1  # max value
+    E_axis = "invalid axis"
+    params_type = Generic()
+    __props__ = ("axis",)
+    _f16_ok = True
+
+    def __init__(self, axis):
+        assert isinstance(axis, tuple | list)
+        self.axis = tuple(axis)
+
+    def get_params(self, node):
+        return self.axis
+
+    def make_node(self, x):
+        x = as_tensor_variable(x)
+
+        # Keep the original shapes for axes on which we do not perform the max.
+        all_axes = set(self.axis)
+        inputs = [x]
+        out_shape = tuple(s for i, s in enumerate(x.type.shape) if i not in all_axes)
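+        # E.g. x with static shape (2, None, 4) and axis=(1,) gives
+        # out_shape == (2, 4): unreduced dims keep their static sizes.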
+        outputs = [
+            tensor(dtype=x.type.dtype, shape=out_shape, name="max"),
+        ]
+        return Apply(self, inputs, outputs)
+
+    def prepare_node(self, node, storage_map, compute_map, impl):
+        if len(node.inputs) == 2:
+            raise ValueError(
+                "You are trying to compile a graph with an old Argmax node. Either reoptimize your graph or rebuild it to get the new node format."
+            )
+
+    def perform(self, node, inp, outs):
+        x = inp[0]
+        axes = self.axis
+        (max,) = outs
+        if axes is None:
+            axes = tuple(range(x.ndim))
+        else:
+            axes = tuple(int(ax) for ax in axes)
+        max[0] = _asarray(np.max(x, axis=axes), dtype=node.outputs[0].dtype)
+
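+    # Note on the C implementation below: NumPy's PyArray_Max accepts a single
+    # axis, and a reduction over *all* axes is requested by passing
+    # axis = NPY_MAXDIMS. For any other subset of axes c_code raises
+    # NotImplementedError, and PyTensor falls back to the Python `perform`.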
+    def c_code(self, node, name, inp, out, sub):
+        if len(self.axis) != 1 and len(self.axis) != node.inputs[0].ndim:
+            raise NotImplementedError(
+                "NumPy C-API can compute max only for 1 axis or for all axes."
+            )
+        x = inp[0]
+        axis = sub["params"]
+        (max,) = out
+        fail = sub["fail"]
+        ret = """
+        #if PY_MAJOR_VERSION >= 3
+            #ifndef PyInt_AS_LONG
+                #define PyInt_AS_LONG PyLong_AS_LONG
+            #endif
+        #endif
+
+        int axis;
+
+        if (PyTuple_GET_SIZE(%(axis)s) == PyArray_NDIM(%(x)s)) {
+            axis = NPY_MAXDIMS;
+        } else if(PyTuple_GET_SIZE(%(axis)s) == 1) {
+            PyObject* axis_object = PyTuple_GET_ITEM(%(axis)s, 0);
+            axis = (int)PyInt_AS_LONG(axis_object);
+            if (axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)) {
+                PyErr_SetString(PyExc_ValueError,
+                "TensorMax: bad axis argument");
+                %(fail)s
+            }
+        } else {
+            PyErr_SetString(PyExc_NotImplementedError,
+            "TensorMax: NumPy C-API can compute max only for 1 axis or for all axes.");
+            %(fail)s
+        }
+
+        Py_CLEAR(%(max)s);
+
+        %(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL);
+        if (%(max)s == NULL) {
+            %(fail)s;
+        }
+        if (!PyArray_CheckExact(%(max)s)) {
+            %(max)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(max)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
+            if(%(max)s == NULL){
+                %(fail)s;
+            }
+        }
+        """
+        return ret % locals()
+
+    def c_code_cache_version(self):
+        return (5,)
+
+    def infer_shape(self, fgraph, node, shapes):
+        ishape = shapes[0]
+        rval = tuple(
+            ishape[i]
+            for (i, b) in enumerate(node.inputs[0].type.broadcastable)
+            if i not in self.axis
+        )
+        return [rval]
+
+    def R_op(self, inputs, eval_points):
+        if eval_points[0] is None:
+            return [None]
+
+        if len(self.axis) != 1:
+            raise ValueError("R_op supported for max only for one axis!")
+        if self.axis[0] > 1:
+            raise ValueError("R_op supported for max only when axis is 0 or 1")
+        if inputs[0].ndim != 2:
+            raise ValueError("R_op supported for max only when input is a matrix")
+        # The perturbation of the max is the perturbation of the input taken
+        # at the position of the max along the reduced axis.
+        max_pos = Argmax(self.axis)(*inputs)
+        if self.axis[0] == 0:
+            return [eval_points[0][max_pos, arange(eval_points[0].shape[1])]]
+        else:
+            return [eval_points[0][arange(eval_points[0].shape[0]), max_pos]]
+
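+    # Identity used by `R_op` (illustrative): for a matrix x and
+    # y = max(x, axis=0), the directional derivative along v is
+    # v[argmax(x, axis=0), arange(x.shape[1])], i.e. v sampled at the argmax.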
+    def grad(self, inp, grads):
+        # The strict-sense mathematical gradient of the maximum function is
+        # not computed here, because it is undefined at any point where
+        # several coordinates tie for the max. Since that set has null
+        # Lebesgue measure, the result may be interpreted as a weak gradient.
+
+        # Note: this should also work correctly for vectors.
+        # The chain-rule term is gMax * dMax/dx. g_max has one dimension
+        # fewer than x, so it must be broadcast back up to x's shape; the
+        # DimShuffle below inserts the reduced axes for that purpose.
+        x = inp[0]
+        axis = as_tensor_variable(self.axis)
+        (g_max,) = grads
+
+        g_max_disconnected = isinstance(g_max.type, DisconnectedType)
+
+        # if the op is totally disconnected, so are its inputs
+        if g_max_disconnected:
+            return [DisconnectedType()()]
+
+        if NoneConst.equals(axis):
+            axis_ = list(range(x.ndim))
+        else:
+            axis_ = axis
+        xmax = max(x, axis_)
+
+        # Raise g_max and xmax to the same number of dimensions as the input.
+        pattern = []
+        out_dim = 0
+        if NoneConst.equals(axis):
+            # We are taking the max over all dimensions.
+            axis = None
+        for i in range(x.ndim):
+            if axis is None or i in axis.data:
+                pattern.append("x")
+            else:
+                pattern.append(out_dim)
+                out_dim += 1
+        g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max)
+        xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax)
+
+        # Set the grad to the correct position.
+        g_x = eq(xmax_pad, x) * g_max_pad
+        return (g_x,)
+
+
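+# A quick check of TensorMax's weak gradient (illustrative sketch using the
+# public wrappers rather than instantiating the Op directly):
+#
+#     import numpy as np
+#     import pytensor
+#     import pytensor.tensor as pt
+#     x = pt.vector("x")
+#     g = pytensor.grad(pt.max(x), x)
+#     g.eval({x: np.array([1.0, 3.0, 2.0])})  # -> array([0., 1., 0.])
+
+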
 class Argmax(COp):
     """
     Calculate the argmax over a given axis or over all axes.
@@ -357,8 +561,10 @@ class Argmax(COp):
     params_type = ParamsType(c_axis=ps.int64)

     def __init__(self, axis):
-        if axis is not None:
-            axis = tuple(axis)
+        assert isinstance(axis, tuple | list)
         self.axis = tuple(axis)

     def get_params(self, node):
@@ -395,6 +601,8 @@ def perform(self, node, inp, outs):
         (max_idx,) = outs
         if axes is None:
             axes = tuple(range(x.ndim))
+        else:
+            axes = tuple(int(ax) for ax in axes)

         # Numpy does not support multiple axes for argmax
         # Work around
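+        # (The kept axes are transposed to the front, the reduced axes are
+        # collapsed into a single trailing dimension, and np.argmax runs over
+        # that axis, yielding indices into the collapsed dimension.)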
@@ -477,7 +685,7 @@ def grad(self, inp, grads):


 @_vectorize_node.register(Argmax)
-@_vectorize_node.register(MaxAndArgmax)
 def vectorize_argmax_node(op, node, batch_x):
     core_ndim = node.inputs[0].type.ndim
     batch_ndim = batch_x.type.ndim - core_ndim
@@ -600,7 +808,9 @@ def max_and_argmax(a, axis=None, keepdims=False):
     axis = check_and_normalize_axes(a, axis)
     if len(axis) == 0:
         axis = list(range(a.type.ndim))
-    out, argout = MaxAndArgmax(axis)(a)
+    out = TensorMax(axis)(a)
+    argout = Argmax(axis)(a)

     if keepdims:
         out = makeKeepDims(a, out, axis)