diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 4dad159f..f2e3d2e8 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -73,7 +73,8 @@ def __getattr__(self, name): def zeros_like(self, ary): def _zeros_like(array): - return self._array_context.zeros(array.shape, array.dtype) + return self._array_context.zeros( + array.shape, array.dtype).copy(axes=array.axes, tags=array.tags) return self._array_context._rec_map_container( _zeros_like, ary, default_scalar=0) @@ -83,7 +84,8 @@ def ones_like(self, ary): def full_like(self, ary, fill_value): def _full_like(subary): - return pt.full(subary.shape, fill_value, subary.dtype) + return pt.full(subary.shape, fill_value, subary.dtype).copy( + axes=subary.axes, tags=subary.tags) return self._array_context._rec_map_container( _full_like, ary, default_scalar=fill_value)