Add xtensor broadcast #1489
base: labeled_tensors
Conversation
@ricardoV94 Here's my attempt to rebase on the changes you just force-pushed. Looks like mypy is unhappy; is that something you expected? Other than that, I think this is ready for review.
Yeah, I didn't make mypy pass yet.
pytensor/xtensor/rewriting/shape.py
Outdated
x_tensor = x_tensor.dimshuffle(shuffle_pattern)

# Now we are aligned with target dims and correct ndim
x_tensor = broadcast_to(x_tensor, out.type.shape)
This won't work when the output shape is not statically known. The target shape has to be computed symbolically from the symbolic input shapes.
You can test by having an xtensor with shape=(None,) for a dim that only that tensor has.
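For illustration, here is a minimal self-contained sketch (made-up variables, not code from this PR) of what computing the target shape symbolically can look like with pt.broadcast_shape and pt.broadcast_to, so that it also covers a dim whose length is only known at runtime:

```python
import numpy as np
import pytensor.tensor as pt
from pytensor import function

x = pt.tensor("x", shape=(2, 1))    # statically known shape
y = pt.tensor("y", shape=(None,))   # length only known at runtime

# Build the target shape from the inputs' symbolic shapes rather than
# from the output type's static shape.
target_shape = pt.broadcast_shape(x, y)
x_bc = pt.broadcast_to(x, target_shape)

f = function([x, y], x_bc.shape)
print(f(np.zeros((2, 1)), np.zeros(5)))  # [2 5]
```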
71bc4ef to 41d9be4
@ricardoV94 I think I have symbolic dimensions working. My solution is more complicated than I think any of us would like, but I don't see a simpler one. Maybe you will. Should we continue working on this PR for now, and I'll rebase later?
Here is an idea:

def lower_broadcast(fgraph, node):
    excluded_dims = node.op.exclude
    broadcast_dims = tuple(dim for dim in node.outputs[0].type.dims if dim not in excluded_dims)
    all_dims = broadcast_dims + excluded_dims

    # align inputs with all_dims like we do in other rewrites
    # probably time to refactor this kind of logic into a helper
    inp_tensors = []
    for inp, out in zip(node.inputs, node.outputs, strict=True):
        inp_dims = inp.type.dims
        order = tuple(inp_dims.index(dim) if dim in inp_dims else "x" for dim in all_dims)
        inp_tensors.append(inp.values.dimshuffle(order))

    if not excluded_dims:
        out_tensors = pt.broadcast_arrays(*inp_tensors)
    else:
        out_tensors = []
        all_shape = tuple(pt.broadcast_shape(*inp_tensors))
        assert len(all_shape) == len(all_dims)
        for inp_tensor, out in zip(inp_tensors, node.outputs):
            out_dims = out.type.dims
            out_shape = tuple(length for length, dim in zip(all_shape, all_dims) if dim in out_dims)
            out_tensors.append(pt.broadcast_to(inp_tensor, out_shape))

    new_outs = [as_xtensor(out_tensor, dims=out.type.dims) for out_tensor, out in zip(out_tensors, node.outputs)]
    return new_outs

Btw the base branch is merged; you can rebase / start from it. Note that you don't need to open a new PR: you can force-push your changes to your current remote after cleaning up the branch.
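As an aside, here is a small self-contained illustration (hypothetical shapes and dim names, not part of the PR) of the alignment step referenced in the sketch above: a dimshuffle pattern with "x" entries inserts broadcastable axes so an input lines up with all_dims before broadcasting.

```python
import pytensor.tensor as pt

inp = pt.tensor("inp", shape=(2, 3))   # an input carrying dims ("a", "b")
inp_dims = ("a", "b")
all_dims = ("a", "c", "b")             # target dim order

# Same pattern-building trick as in the sketch: reuse existing axes, add "x" for missing dims
order = tuple(inp_dims.index(dim) if dim in inp_dims else "x" for dim in all_dims)
aligned = inp.dimshuffle(order)

print(order)               # (0, 'x', 1)
print(aligned.type.shape)  # (2, 1, 3)
```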
@ricardoV94 I've added
Your version of I'll work on debugging it, but at the moment it's not clear to me whether there is a small error in your implementation or an actual problem with the logic.
I suspect some wrong assumption about the excluded dims alignment, but the general idea should work.
I think the incorrect assumption is that all outputs have the same shape. When exclude is not empty, they don't, in general.
Actually, there's a logical flaw. Two inputs could have an excluded dim with the same name but different length, in which case they shouldn't be aligned for the broadcast shape. We should add that as a test. Still, the logic for each output should be something like
I didn't assume that; the dimshuffle was supposed to take care of it so that things were put on different axes for broadcasting. Still, as I just wrote, there was a wrong assumption that you could align shared excluded dims. They don't even come out in a uniform order, do they?
I don't think this logical flaw is why the tests are failing, though. We should test that case as well.
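To make the case being discussed concrete, here is a hypothetical example (made-up arrays) using xarray itself as the reference semantics: two inputs share an excluded dim with different lengths, and the excluded dim must not take part in the broadcast.

```python
import numpy as np
import xarray as xr

x = xr.DataArray(np.zeros((2, 3)), dims=("a", "b"))
y = xr.DataArray(np.zeros((2, 5)), dims=("a", "b"))  # same dim "b", different length

# With "b" excluded, xarray does not try to align or broadcast it, so this is valid
x_b, y_b = xr.broadcast(x, y, exclude=["b"])
print(x_b.sizes["b"], y_b.sizes["b"])  # 3 5: each input keeps its own "b" length
```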
@AllenDowney does this work? If not, what case fails?

from pytensor.xtensor.type import xtensor
from pytensor.xtensor.math import second


def broadcast(array, *arrays, exclude=()):
    if isinstance(exclude, str):
        exclude = (exclude,)

    def sum_excluded_dims(array):
        if not exclude:
            return array
        dims = array.dims
        array_exclude = tuple(e for e in exclude if e in dims)
        if not array_exclude:
            return array
        return array.sum(array_exclude)

    def align_excluded_dims(array):
        if not exclude:
            return array
        dims = array.dims
        array_exclude = tuple(e for e in exclude if e in dims)
        if not array_exclude:
            return array
        return array.transpose(..., *array_exclude)

    if not arrays:
        return array

    # Find broadcast shape by doing nested second after excluding via `sum`
    # The sum operation will be removed in rewrites, since only the shape matters
    broadcast_array = sum_excluded_dims(array)
    for other_array in arrays:
        # second is equivalent to `np.broadcast_arrays(x, y)[1]`
        broadcast_array = second(broadcast_array, sum_excluded_dims(other_array))

    # Broadcast each original array with the broadcast_array
    # We further align the excluded dims according to the order given by the user, like xarray does
    return tuple(second(broadcast_array, align_excluded_dims(arr)) for arr in (array, *arrays))


x = xtensor(dims=("a", "b", "c"))
y = xtensor(dims=("a", "d"))
z = xtensor(dims=("a", "f", "b"))

for out in broadcast(x, y, z, exclude=("b", "f")):
    print(out.dims)
# ('a', 'c', 'd', 'b')
# ('a', 'c', 'd')
# ('a', 'c', 'd', 'b', 'f')
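For reference, a quick NumPy analogue (made-up arrays) of the behaviour the comment above attributes to second, namely np.broadcast_arrays(x, y)[1]:

```python
import numpy as np

x = np.zeros((2, 1))
y = np.ones((1, 3))

# "second(x, y)" is described above as returning the second argument
# broadcast against the first, i.e. y's values at the common shape.
second_like = np.broadcast_arrays(x, y)[1]
print(second_like.shape)  # (2, 3)
```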
This replaces #1486. This one is based on a rebased labeled_tensor branch.

📚 Documentation preview 📚: https://pytensor--1489.org.readthedocs.build/en/1489/