
Conversation

@ricardoV94 (Member) commented Jun 23, 2025

Closes #1420

There was a performance-related hack in the InplaceElemwiseOptimizer that tried to avoid validating the graph after replacing each node.

This is a bad idea because it may revert valid rewrites that happen to be caught in the same "check" window as an invalid rewrite. More importantly, since we started inplacing on multi-output Elemwise, it could trigger an exception when calling has_destroyers on subsequent nodes, due to a previous invalid replacement.

It was hard to track down this issue because the special behavior was only triggered once a graph had more than 500 nodes.

After the fix I noticed that most of the rewrite's logic can be shared with Blockwise, so I went ahead and refactored that as well, which also closes #1457.
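To make the failure mode concrete, here is a minimal sketch (not the optimizer's actual code; the helper name and the caught exception type are placeholders) of the pattern the fix moves to: replace and validate per node, so an invalid in-place replacement is reverted immediately instead of being discovered later by has_destroyers.

    def try_inplace_node(fgraph, node, inplace_node, reason="inplace_rewrite"):
        # Pair each original output with the corresponding in-place output
        replacements = tuple(zip(node.outputs, inplace_node.outputs))
        try:
            # replace_all_validate = replace + fgraph.validate(); the
            # DestroyHandler gets a chance to reject the replacement (e.g.
            # two aliased inputs being destroyed) and revert it right away.
            fgraph.replace_all_validate(replacements, reason=reason)
        except Exception:
            # The real rewrite catches the validation error type specifically;
            # the point is that a failure here leaves the graph valid for the
            # next node's has_destroyers check.
            return False
        return True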


📚 Documentation preview 📚: https://pytensor--1494.org.readthedocs.build/en/1494/


codecov bot commented Jun 23, 2025

Codecov Report

Attention: Patch coverage is 90.53254% with 16 lines in your changes missing coverage. Please review.

Project coverage is 82.01%. Comparing base (236e50d) to head (eca7bdc).
Report is 17 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/tensor/rewriting/elemwise.py 92.72% 5 Missing and 3 partials ⚠️
pytensor/graph/destroyhandler.py 86.66% 3 Missing and 1 partial ⚠️
pytensor/tensor/rewriting/blockwise.py 86.20% 2 Missing and 2 partials ⚠️

❌ Your patch status has failed because the patch coverage (90.53%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1494      +/-   ##
==========================================
+ Coverage   81.98%   82.01%   +0.03%     
==========================================
  Files         231      231              
  Lines       52192    52299     +107     
  Branches     9185     9207      +22     
==========================================
+ Hits        42790    42894     +104     
- Misses       7094     7095       +1     
- Partials     2308     2310       +2     
Files with missing lines Coverage Δ
pytensor/graph/destroyhandler.py 71.85% <86.66%> (+2.40%) ⬆️
pytensor/tensor/rewriting/blockwise.py 96.07% <86.20%> (-1.28%) ⬇️
pytensor/tensor/rewriting/elemwise.py 92.72% <92.72%> (+0.94%) ⬆️

... and 13 files with indirect coverage changes


@ricardoV94 ricardoV94 force-pushed the has_destroyers_bug branch 2 times, most recently from 65f2e35 to 6475c85 Compare June 24, 2025 00:11
@ricardoV94 ricardoV94 marked this pull request as ready for review June 24, 2025 09:05
Copilot AI left a comment

Pull Request Overview

This PR ensures validation after each in-place rewrite in the Elemwise optimizer, refactors shared logic for Blockwise in-place rewrites, and adds regression tests for both cases.

  • Enforce full graph validation between consecutive Elemwise in-place replacements
  • Refactor Blockwise in-place optimizer to extend the shared InplaceGraphOptimizer
  • Add tests for regression in Elemwise and partial in-place behavior in Blockwise

Reviewed Changes

Copilot reviewed 4 out of 5 changed files in this pull request and generated no comments.

File Description
tests/tensor/test_blockwise.py Add test_partial_inplace to cover Blockwise in-place behavior
tests/tensor/rewriting/test_elemwise.py Add test_InplaceElemwiseOptimizer_bug regression test for #1420
pytensor/tensor/rewriting/blockwise.py Refactor Blockwise in-place rewrite using InplaceGraphOptimizer
pytensor/graph/destroyhandler.py Introduce inplace_candidates helper for unified candidate logic
Comments suppressed due to low confidence (4)

tests/tensor/rewriting/test_elemwise.py:1535

  • [nitpick] Test function names should use snake_case. Consider renaming this to test_inplace_elemwise_optimizer_bug for consistency with pytest conventions.
def test_InplaceElemwiseOptimizer_bug():

tests/tensor/rewriting/test_elemwise.py:1544

  • The test references Elemwise but there is no corresponding import. Please add from pytensor.tensor.rewriting.elemwise import Elemwise at the top of the file.
    out1, out2 = Elemwise(ps.Composite([z1, z2], [z1 + z2, z2 - z1]))(z[1:], z[:-1])

pytensor/tensor/rewriting/blockwise.py:4

  • The imported vectorize_node is no longer used in this file. Consider removing this import to clean up unused code.
from pytensor.graph.replace import vectorize_node

pytensor/tensor/rewriting/blockwise.py:5

  • Neither copy_stack_trace nor out2in are used after the refactor. Removing these unused imports will improve code clarity.
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in

@ricardoV94 (Member Author) commented Jun 24, 2025

Actually, I suspect that with my second commit the inplace optimizer can't accidentally introduce invalid graphs anymore. If that's the case, we could do a single validate at the end of the rewrite. I didn't actually investigate what's so slow about validate.

That's lower priority than the fix here, though.

@ricardoV94 ricardoV94 force-pushed the has_destroyers_bug branch from 6475c85 to 1438953 Compare June 24, 2025 12:17
@jessegrabowski (Member) left a comment

Approving with the caveat that a lot of this is somewhat over my head.

The point is that when we have an elemwise Op (or blockwise it turns out) that can be rewritten to inplace, we were previously missing a validation check? So in cases where only a subset of inputs could be destroyed, all inputs were being destroyed? But only if the graph had more than 500 nodes, because that triggered a validation skip to save compute?

Very impressed you tracked this down!

}
large_graph = len(fgraph.apply_nodes) > 500
Member

500 is a bit of a magic number; should we be reading it from pytensorrc?

Member Author

Maybe but I would leave that for another time

Member Author

Opened issue #1515

root_destroyer = fgraph.destroy_handler.root_destroyer

update_mapping = fgraph.update_mapping or {}
op_updates: dict[TensorVariable, TensorVariable] = {
Member

When are these inline typehints useful? I've tried to throw them into the mypy volcano before to make it happy, but it never worked.

Member Author

You won't be getting mypy insights from me :)

Usually when there are two branches that assign different things, or the dict may be empty some of the time, you have to tell it explicitly.
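For example (an illustration of the point above, not code from this PR; the update_mapping semantics in the loop are assumed for the sake of the example): when a dict starts empty and is only filled conditionally, mypy cannot infer the value types without the annotation.

    from pytensor.tensor.variable import TensorVariable

    def collect_op_updates(fgraph) -> dict:
        # Without the annotation mypy reports "Need type annotation for
        # 'op_updates'", because the dict may stay empty on some paths.
        op_updates: dict[TensorVariable, TensorVariable] = {}
        update_mapping = fgraph.update_mapping or {}
        for out_idx, in_idx in update_mapping.items():
            # Assumes update_mapping maps output index -> updated input index
            op_updates[fgraph.outputs[out_idx]] = fgraph.inputs[in_idx]
        return op_updates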

)
sorted_candidate_pairs = candidate_pairs
if op_updates and (node_updates := set(node.outputs) & set_op_updates):
# If the fgraph has updates, we try to prioritize in-placing on the pairs that correspond to the update
Member

I quite like the old (more verbose) comments

Member Author

I feel the variable names direct/indirect/other updates are pretty self-explanatory?
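A rough sketch of the prioritization being discussed (names and the pair layout are illustrative, not the PR's actual variables): candidate (output, input) pairs that realize a graph update are tried for in-placing before the remaining pairs.

    def prioritize_update_pairs(candidate_pairs, update_outputs):
        # candidate_pairs: (output, input) pairs eligible for in-placing
        # update_outputs: outputs of this node that correspond to fgraph updates
        direct = [pair for pair in candidate_pairs if pair[0] in update_outputs]
        other = [pair for pair in candidate_pairs if pair[0] not in update_outputs]
        # Trying the update pairs first makes it more likely the update itself
        # ends up in-place, which is what matters for shared-variable updates.
        return direct + other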

def test_InplaceElemwiseOptimizer_bug():
# Regression test for https://github.com/pymc-devs/pytensor/issues/1420

# This graph fails if InplaceElemwiseOptimizer were to try to skip `fgraph.validate`
Member

Do you mean fgraph.replace_all_validate in this comment?

Member Author

replace_all_validate is replace + validate. It doesn't matter which method you use; the problem was skipping the validate part.

@@ -92,8 +71,7 @@ def apply(self, fgraph):
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)

"""
# We should not validate too often as this takes too much time to
# execute!
# We should not validate too often as this takes too much time to execute!
Member

Is this something egglog could help with?

Member Author

Not really

view_of_protected = False
root = inp
try:
while True:
Member

I don't understand this. If root is in protected_inputs, how do we ever exit this loop? It doesn't seem like you're looping over anything

@ricardoV94 (Member Author) Jul 2, 2025

view_i[root] for a root variable is None, so next time it tries view_i[None] and fails with a KeyError.

Member Author

70% confidence

Member Author

view_i doesn't have entries if a variable is already a root; checked now.

Member

oh it's reassigning the root variable inside the while loop. That's what I was missing.
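For reference, a minimal sketch of the loop being discussed (not the actual destroyhandler code), under the checked assumption that view_i maps a variable to the variable it is a view of and simply has no entry for a root:

    def find_view_root(inp, view_i, protected_inputs):
        view_of_protected = False
        root = inp
        try:
            while True:
                if root in protected_inputs:
                    view_of_protected = True
                # Re-assigning root is what advances the walk; a KeyError here
                # means root is not a view of anything, i.e. it is the root.
                root = view_i[root]
        except KeyError:
            pass
        return root, view_of_protected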

if inplace_node.op.destroy_map == inplace_pattern:
replacements = tuple(zip(node.outputs, inplace_node.outputs))
try:
fgraph.replace_all_validate(
Member

If I understand correctly, this is the key bugfix here?

Member Author

Yes, but actually I think I solved the problem by being more clever about the candidates

@ricardoV94 (Member Author)

> The point is that when we have an elemwise Op (or blockwise it turns out) that can be rewritten to inplace, we were previously missing a validation check? So in cases where only a subset of inputs could be destroyed, all inputs were being destroyed? But only if the graph had more than 500 nodes, because that triggered a validation skip to save compute?

The problem was that Elemwise sometimes tried to inplace on multiple input/output pairs, and those inputs could be aliases of each other. If it had called validate immediately it would have opted out, but it didn't. Then the next node would call has_destroyers, which would raise an error because it saw the same variable being destroyed twice (by the previous rewrite).
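A concrete instance of this, based on the regression test quoted earlier in this PR (ps is pytensor.scalar; the dtype is pinned to float64 here only so the snippet stands alone): z[1:] and z[:-1] are overlapping views of the same z, so at most one of the two input/output pairs of the multi-output Elemwise can be made in-place, and validating only later let both go through.

    import pytensor.scalar as ps
    import pytensor.tensor as pt
    from pytensor.tensor.elemwise import Elemwise

    z = pt.vector("z", dtype="float64")
    z1 = ps.float64("z1")
    z2 = ps.float64("z2")
    # Two outputs computed elementwise from two overlapping views of z:
    # destroying both inputs in-place would destroy the same buffer twice.
    out1, out2 = Elemwise(ps.Composite([z1, z2], [z1 + z2, z2 - z1]))(z[1:], z[:-1])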

@ricardoV94 ricardoV94 force-pushed the has_destroyers_bug branch from 1438953 to eca7bdc Compare July 2, 2025 09:01
@ricardoV94 ricardoV94 merged commit 45a33ad into pymc-devs:main Jul 2, 2025
73 of 74 checks passed