Do not skip validation between consecutive Elemwise inplace replacements #1494
Conversation
Force-pushed from fdc5153 to b05557e
Codecov Report
Attention: Patch coverage is 90.53%.
❌ Your patch status has failed because the patch coverage (90.53%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.
@@ Coverage Diff @@
## main #1494 +/- ##
==========================================
+ Coverage 81.98% 82.01% +0.03%
==========================================
Files 231 231
Lines 52192 52299 +107
Branches 9185 9207 +22
==========================================
+ Hits 42790 42894 +104
- Misses 7094 7095 +1
- Partials 2308 2310 +2
Force-pushed from 65f2e35 to 6475c85
Pull Request Overview
This PR ensures validation after each in-place rewrite in the Elemwise optimizer, refactors shared logic for Blockwise in-place rewrites, and adds regression tests for both cases.
- Enforce full graph validation between consecutive Elemwise in-place replacements
- Refactor the Blockwise in-place optimizer to extend the shared InplaceGraphOptimizer
- Add tests for the Elemwise regression and for partial in-place behavior in Blockwise
Reviewed Changes
Copilot reviewed 4 out of 5 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| tests/tensor/test_blockwise.py | Add test_partial_inplace to cover Blockwise in-place behavior |
| tests/tensor/rewriting/test_elemwise.py | Add test_InplaceElemwiseOptimizer_bug regression test for #1420 |
| pytensor/tensor/rewriting/blockwise.py | Refactor Blockwise in-place rewrite using InplaceGraphOptimizer |
| pytensor/graph/destroyhandler.py | Introduce inplace_candidates helper for unified candidate logic |
Comments suppressed due to low confidence (4)
tests/tensor/rewriting/test_elemwise.py:1535
- [nitpick] Test function names should use snake_case. Consider renaming this to test_inplace_elemwise_optimizer_bug for consistency with pytest conventions.
def test_InplaceElemwiseOptimizer_bug():
tests/tensor/rewriting/test_elemwise.py:1544
- The test references Elemwise but there is no corresponding import. Please add from pytensor.tensor.rewriting.elemwise import Elemwise at the top of the file.
out1, out2 = Elemwise(ps.Composite([z1, z2], [z1 + z2, z2 - z1]))(z[1:], z[:-1])
pytensor/tensor/rewriting/blockwise.py:4
- The imported vectorize_node is no longer used in this file. Consider removing this import to clean up unused code.
from pytensor.graph.replace import vectorize_node
pytensor/tensor/rewriting/blockwise.py:5
- Neither copy_stack_trace nor out2in is used after the refactor. Removing these unused imports will improve code clarity.
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
Actually I suspect that with my second commit the inplace optimizer can't accidentally introduce invalid graphs anymore. If that's the case we could do a single validate at the end of the rewrite. I didn't actually try to see what's so slow about validate; that's lower priority than the fix here.
Force-pushed from 6475c85 to 1438953
Approving, with the caveat that a lot of this is somewhat over my head.
The point is that when we have an elemwise Op (or blockwise it turns out) that can be rewritten to inplace, we were previously missing a validation check? So in cases where only a subset of inputs could be destroyed, all inputs were being destroyed? But only if the graph had more than 500 nodes, because that triggered a validation skip to save compute?
Very impressed you tracked this down!
}
large_graph = len(fgraph.apply_nodes) > 500
500 is a bit of a magic number, should we be reading it from pytensorrc?
Maybe but I would leave that for another time
Opened issue #1515
root_destroyer = fgraph.destroy_handler.root_destroyer

update_mapping = fgraph.update_mapping or {}
op_updates: dict[TensorVariable, TensorVariable] = {
When are these inline typehints useful? I've tried to throw them into the mypy volcano before to make it happy, but it never worked.
You won't be getting mypy insights from me :)
Usually when there are two branches that assign different things, or the dict may be empty some of the time, then you have to tell it.
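For what it's worth, a minimal sketch of the situation described, using the annotated empty dict from the diff above (the import path is the usual PyTensor one, assumed here):

```python
from pytensor.tensor.variable import TensorVariable

# Without the annotation, mypy cannot infer the key/value types of an empty
# dict literal, so later assignments (or different branches filling it with
# different things) trip it up; the inline hint pins the mapping's types.
op_updates: dict[TensorVariable, TensorVariable] = {}
```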
)
sorted_candidate_pairs = candidate_pairs
if op_updates and (node_updates := set(node.outputs) & set_op_updates):
    # If the fgraph has updates, we try to prioritize in-placing on the pairs that correspond to the update
I quite like the old (more verbose) comments
I feel the variable names direct/indirect/other updates are pretty self-explanatory?
def test_InplaceElemwiseOptimizer_bug():
    # Regression test for https://github.com/pymc-devs/pytensor/issues/1420

    # This graph fails if InplaceElemwiseOptimizer were to try to skip `fgraph.validate`
Do you mean fgraph.replace_all_validate in this comment?
replace_all_validate is replace + validate. It doesn't matter which method you use; the problem was skipping the validate part.
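A rough sketch of that relationship, for readers unfamiliar with the FunctionGraph API (hedged: the real ReplaceValidate feature also records history and rolls back the replacements if validation fails):

```python
def replace_all_validate_sketch(fgraph, replacements, reason=None):
    # Conceptually: apply every (old, new) substitution, then run the graph
    # validators; the real method additionally undoes the replacements when
    # validation raises.
    for old, new in replacements:
        fgraph.replace(old, new, reason=reason)
    fgraph.validate()
```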
@@ -92,8 +71,7 @@ def apply(self, fgraph):
         (x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)

         """
-        # We should not validate too often as this takes too much time to
-        # execute!
+        # We should not validate too often as this takes too much time to execute!
Is this something egglog could help with?
Not really
view_of_protected = False
root = inp
try:
    while True:
I don't understand this. If root is in protected_inputs, how do we ever exit this loop? It doesn't seem like you're looping over anything.
view_i[root] for a root variable is None, so next time it tries view_i[None] and fails with a KeyError.
70% confidence
view_i doesn't have entries if a variable is already a root, checked now
Oh, it's reassigning the root variable inside the while loop. That's what I was missing.
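To make the resolution concrete, here is a toy sketch of the pattern under discussion; the names view_i, protected_inputs, and inp mirror the diff context, and the data below is hypothetical, just to show how reassigning root eventually raises the KeyError that ends the loop:

```python
view_i = {"b": "a", "c": "b"}  # each view maps to the variable it views; roots have no entry
protected_inputs = {"a"}
inp = "c"

view_of_protected = False
root = inp
try:
    while True:
        if root in protected_inputs:
            view_of_protected = True
        # Walk one step up the view chain; a true root has no view_i entry,
        # so this lookup raises KeyError and terminates the loop.
        root = view_i[root]
except KeyError:
    pass

print(view_of_protected)  # True: "c" views "b", which views the protected "a"
```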
if inplace_node.op.destroy_map == inplace_pattern:
    replacements = tuple(zip(node.outputs, inplace_node.outputs))
    try:
        fgraph.replace_all_validate(
If I understand correctly, this is the key bugfix here?
Yes, but actually I think I solved the problem by being more clever about the candidates.
The problem was that Elemwise sometimes tried to inplace on multiple input/output pairs, and those multiple inputs could be aliases of each other. If it had called
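A plain-numpy illustration (not PyTensor code) of why two aliased candidates cannot both be destroyed; the overlapping views mirror the z[1:] / z[:-1] pattern from the regression test:

```python
import numpy as np

z = np.arange(5.0)
a, b = z[1:], z[:-1]  # overlapping views of the same buffer
a += b                # "destroy" a in place ...
print(b)              # ... and b has silently changed too: [0. 1. 3. 5.]
```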
Force-pushed from 1438953 to eca7bdc
Closes #1420
There was a performance-related hack in the InplaceElemwiseOptimizer where it tried to avoid validating the graph after replacing each node.
This is a bad idea because it may revert valid rewrites that happen to be caught in the same "check" window as an invalid rewrite. More importantly, since we started inplacing on multi-output Elemwise, it could trigger an exception when trying to call has_destroyers on subsequent nodes, due to a previous invalid replacement. It was hard to track down this issue because the special behavior was only triggered once a graph had more than 500 nodes.
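For readers who want to see the shape of the graph involved, here is a hedged reconstruction based on the snippet quoted in the Copilot review above; the actual regression test in the PR builds a much larger graph, since the old fast path (and hence the bug) only kicked in above 500 nodes:

```python
import numpy as np
import pytensor
import pytensor.scalar as ps
import pytensor.tensor as pt
from pytensor.tensor.elemwise import Elemwise

z = pt.vector("z")
z1, z2 = ps.float64("z1"), ps.float64("z2")
# Multi-output Composite applied to two overlapping views of the same input:
# the in-place optimizer sees aliased candidates for destruction.
out1, out2 = Elemwise(ps.Composite([z1, z2], [z1 + z2, z2 - z1]))(z[1:], z[:-1])

fn = pytensor.function([z], [out1, out2])
print(fn(np.arange(5.0)))  # [array([1., 3., 5., 7.]), array([-1., -1., -1., -1.])]
```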
After the refactor I noticed that most of the logic of the rewrite can be shared with Blockwise, so I went ahead and refactored it, which also closes #1457
📚 Documentation preview 📚: https://pytensor--1494.org.readthedocs.build/en/1494/