Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 43 additions & 17 deletions pytensor/graph/destroyhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import itertools
from collections import deque

import pytensor
from pytensor.configdefaults import config
from pytensor.graph.basic import Constant
from pytensor.graph.features import AlreadyThere, Bookkeeper
Expand Down Expand Up @@ -223,7 +222,7 @@
return droot, impact, root_destroyer


def fast_inplace_check(fgraph, inputs):
def inplace_candidates(fgraph, inputs, protected_inputs=None):
"""
Return the variables in `inputs` that are possible candidates to be used as
inputs of an inplace operation.
Expand All @@ -234,22 +233,49 @@
Input variables that you want to use as inplace destinations.

"""
Supervisor = pytensor.compile.function.types.Supervisor
protected_inputs = list(
itertools.chain.from_iterable(
f.protected for f in fgraph._features if isinstance(f, Supervisor)
if protected_inputs is None:
from pytensor.compile.function.types import Supervisor

Check warning on line 237 in pytensor/graph/destroyhandler.py

View check run for this annotation

Codecov / codecov/patch

pytensor/graph/destroyhandler.py#L237

Added line #L237 was not covered by tests

protected_inputs = set(

Check warning on line 239 in pytensor/graph/destroyhandler.py

View check run for this annotation

Codecov / codecov/patch

pytensor/graph/destroyhandler.py#L239

Added line #L239 was not covered by tests
itertools.chain.from_iterable(
f.protected for f in fgraph._features if isinstance(f, Supervisor)
)
)
)
protected_inputs.extend(fgraph.outputs)

inputs = [
i
for i in inputs
if not isinstance(i, Constant)
and not fgraph.has_destroyers([i])
and i not in protected_inputs
]
return inputs
protected_inputs.update(fgraph.outputs)

Check warning on line 244 in pytensor/graph/destroyhandler.py

View check run for this annotation

Codecov / codecov/patch

pytensor/graph/destroyhandler.py#L244

Added line #L244 was not covered by tests

has_destroyers = fgraph.has_destroyers
view_i = fgraph.destroy_handler.view_i
candidate_roots = {}
candidate_inputs = []
for inp in inputs:
if isinstance(inp, Constant):
# Can't inplace on constants.
continue

# Find the root of the view chain, and while traversing check if it passes on any protected inputs.
view_of_protected = False
root = inp
try:
while True:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand this. If root is in protected_inputs, how do we ever exit this loop? It doesn't seem like you're looping over anything

Copy link
Member Author

@ricardoV94 ricardoV94 Jul 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

view_i[root] for a root variable is None, so next time it tries view_i[None] and fails with a KeyError.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

70% confidence

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

view_i doesn't have entries if a variable is already a root, checked now

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh it's reassigning the root variable inside the while loop. That's what I was missing.

if root in protected_inputs:
view_of_protected = True
root = view_i[root]
except KeyError:
pass

if root in candidate_roots:
# Another input views on the same root, we can't destroy either
if (invalid_candidate := candidate_roots[root]) is not None:
# Invalidate the previous candidate
candidate_inputs.remove(invalid_candidate)
candidate_roots[root] = None
elif not view_of_protected and not has_destroyers([inp]):
candidate_inputs.append(inp)
candidate_roots[root] = inp
else:
candidate_roots[root] = None

return candidate_inputs


class DestroyHandler(Bookkeeper):
Expand Down
124 changes: 63 additions & 61 deletions pytensor/tensor/rewriting/blockwise.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import itertools

from pytensor.compile import Supervisor
from pytensor.compile.mode import optdb
from pytensor.graph import Constant, node_rewriter
from pytensor.graph.destroyhandler import inplace_candidates
from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot
Expand All @@ -13,6 +11,7 @@
register_specialize,
register_stabilize,
)
from pytensor.tensor.rewriting.elemwise import InplaceGraphOptimizer
from pytensor.tensor.shape import Reshape
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
Expand Down Expand Up @@ -262,74 +261,77 @@
return [x[(*none_slices, *core_idxs)]]


@node_rewriter(tracks=[Blockwise], inplace=True)
def blockwise_inplace(fgraph, node):
blockwise_op = node.op

if blockwise_op.destroy_map:
# Op already has inplace
return

# Find out valid inputs for inplacing
batch_ndim = blockwise_op.batch_ndim(node)
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]

protected_inputs = [
f.protected for f in fgraph._features if isinstance(f, Supervisor)
]
protected_inputs = list(itertools.chain.from_iterable(protected_inputs))
protected_inputs.extend(fgraph.outputs)
allowed_inplace_inputs = [
idx
for idx, inp in enumerate(node.inputs)
if
(
# Constants would need to be recreated every time if inplaced
not isinstance(inp, Constant)
# We can only inplace on inputs that are not being broadcasted
# As those are reused across iterations of Blockwise
and node.inputs[idx].type.broadcastable[:batch_ndim] == out_batch_bcast
# Inputs that are marked as protected or destroyed can't be inplaced
and not fgraph.has_destroyers([inp])
and inp not in protected_inputs
class InplaceBlockwiseOptimizer(InplaceGraphOptimizer):
op = Blockwise

def filter_candidate_pairs(self, fgraph, node, protected_inputs):
blockwise_op = node.op
batch_ndim = blockwise_op.batch_ndim(node)
out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
inputs = node.inputs

candidate_inputs = set(
inplace_candidates(
fgraph,
[
inp
for inp in inputs
if inp.type.broadcastable[:batch_ndim] == out_batch_bcast
],
protected_inputs=protected_inputs,
)
)
]

if not allowed_inplace_inputs:
return None
allowed_inplace_inputs = [
i for i, inp in enumerate(inputs) if inp in candidate_inputs
]
destroy_map = blockwise_op.core_op.inplace_on_inputs(
allowed_inplace_inputs=allowed_inplace_inputs
).destroy_map

if not destroy_map:
return []

outputs = node.outputs
return [
((out_idx, outputs[out_idx]), (inp_idx, inputs[inp_idx]))
for out_idx, inp_idxs in destroy_map.items()
for inp_idx in inp_idxs
]

inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
allowed_inplace_inputs=allowed_inplace_inputs
)
def create_inplace_node(self, node, inplace_pattern):
blockwise_op = node.op
allowed_inplace_inputs = tuple(v[0] for v in inplace_pattern.values())
inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
allowed_inplace_inputs=allowed_inplace_inputs
)

if not inplace_core_op.destroy_map:
return None
if not inplace_core_op.destroy_map:
return node

Check warning on line 310 in pytensor/tensor/rewriting/blockwise.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/rewriting/blockwise.py#L310

Added line #L310 was not covered by tests

# Check Op is not trying to inplace on non-candidate inputs
for destroyed_inputs in inplace_core_op.destroy_map.values():
for destroyed_input in destroyed_inputs:
if destroyed_input not in allowed_inplace_inputs:
raise ValueError(
f"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}"
)
# Check Op is not trying to inplace on non-candidate inputs
for destroyed_inputs in inplace_core_op.destroy_map.values():
for destroyed_input in destroyed_inputs:
if destroyed_input not in allowed_inplace_inputs:
raise ValueError(

Check warning on line 316 in pytensor/tensor/rewriting/blockwise.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/rewriting/blockwise.py#L316

Added line #L316 was not covered by tests
f"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}"
)

# Recreate core_op with inplace
inplace_blockwise_op = Blockwise(
core_op=inplace_core_op,
signature=blockwise_op.signature,
name=blockwise_op.name,
gufunc_spec=blockwise_op.gufunc_spec,
destroy_map=inplace_core_op.destroy_map,
)
# Recreate core_op with inplace
inplace_blockwise_op = type(blockwise_op)(
core_op=inplace_core_op,
signature=blockwise_op.signature,
name=blockwise_op.name,
gufunc_spec=blockwise_op.gufunc_spec,
destroy_map=inplace_core_op.destroy_map,
)

out = inplace_blockwise_op.make_node(*node.inputs).outputs
copy_stack_trace(node.outputs, out)
return out
return inplace_blockwise_op.make_node(*node.inputs)


optdb.register(
"blockwise_inplace",
in2out(blockwise_inplace),
InplaceBlockwiseOptimizer(),
"fast_run",
"inplace",
position=50.1,
Expand Down
Loading