Skip to content

Commit 1af9231

Browse files
eellisonpytorchmergebot
authored andcommitted
Replace IndexingDiv with FloorDiv in test_torchinductor (#93003)
Holdover from #92878 Pull Request resolved: #93003 Approved by: https://github.com/ngimel
1 parent 1f55f3b commit 1af9231

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

test/inductor/test_torchinductor.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
from torch._inductor.codegen.cpp import cexpr, CppOverrides, CppVecOverrides
4949
from torch._inductor.codegen.triton import texpr
5050
from torch._inductor.compile_fx import compile_fx, complex_memory_overlap
51-
from torch._inductor.ir import IndexingDiv, ModularIndexing
51+
from torch._inductor.ir import ModularIndexing
5252
from torch._inductor.overrides import (
5353
linear_permute_fusion,
5454
linear_transpose,
@@ -60,6 +60,7 @@
6060
)
6161
from torch._inductor.sizevars import SizeVarAllocator
6262
from torch._inductor.utils import has_torchvision_roi_align, timed
63+
from torch.fx.experimental.symbolic_shapes import FloorDiv
6364

6465
# This will only pass on pytorch builds newer than roughly 5/15/2022
6566
assert get_decompositions([torch.ops.aten.trace])
@@ -552,24 +553,24 @@ def test_indexing_simplification(self):
552553
self.assertEqual(
553554
sizevars.simplify_with_ranges(expr, var_ranges), i1 + 128 * i2 + 64 * r3
554555
)
555-
# if there are negative terms in ModularIndexing base, we cannot replace it with IndexingDiv
556+
# if there are negative terms in ModularIndexing base, we cannot replace it with FloorDiv
556557
expr = ModularIndexing(i1 - 15, 1, 64)
557558
self.assertEqual(
558559
sizevars.simplify_with_ranges(expr, var_ranges),
559560
ModularIndexing(i1 - 15, 1, 64),
560561
)
561562
# small terms should be kept if the rest is not guaranteed to be divisible
562563
self.assertEqual(
563-
sizevars.simplify_with_ranges(IndexingDiv(r3 + i2 + i1, 32), var_ranges),
564-
IndexingDiv(r3 + i2 + i1, 32),
564+
sizevars.simplify_with_ranges(FloorDiv(r3 + i2 + i1, 32), var_ranges),
565+
FloorDiv(r3 + i2 + i1, 32),
565566
)
566567

567568
expr = ModularIndexing(2 * i2 + r3, 1, 64)
568569
# modular indexing is removed if base is smaller than modulo
569570
self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), 2 * i2 + r3)
570571

571572
# check the same thing but with symbolic divisor
572-
self.assertEqual(IndexingDiv(r3 * i0, r3), i0)
573+
self.assertEqual(FloorDiv(r3 * i0, r3), i0)
573574
self.assertEqual(ModularIndexing(r3 * i0, r3, 10), ModularIndexing(i0, 1, 10))
574575

575576
# (10*i) % 10 is always zero and should get optimized away
@@ -597,7 +598,7 @@ def test_indexing_simplification(self):
597598

598599
# Constant fold from divisor into base
599600
self.assertEqual(ModularIndexing(i0 * 4, 2, 10), ModularIndexing(i0 * 2, 1, 10))
600-
self.assertEqual(IndexingDiv(i0 * 4, 2), i0 * 2)
601+
self.assertEqual(FloorDiv(i0 * 4, 2), i0 * 2)
601602

602603
# Nested modular indexing is correctly simplified
603604
var_ranges = {"i1": 13, "i2": 121}
@@ -607,7 +608,7 @@ def test_indexing_simplification(self):
607608
self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr)
608609
var_ranges = {"i2": 784}
609610
expr = ModularIndexing(ModularIndexing(i2, 1, 28), 7, 4)
610-
expected = IndexingDiv(ModularIndexing(i2, 1, 28), 7)
611+
expected = FloorDiv(ModularIndexing(i2, 1, 28), 7)
611612
self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expected)
612613
expr = ModularIndexing(ModularIndexing(i2, 1, 28) + 1, 7, 4)
613614
self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr)
@@ -654,8 +655,8 @@ def test_indexing_join(self):
654655
ModularIndexing(i0, 10, i1 * i2) + 10,
655656
)
656657

657-
# works for ModularIndexing + IndexingDiv
658-
expr5 = 197 * IndexingDiv(i0, 197) + ModularIndexing(i0, 1, 197)
658+
# works for ModularIndexing + FloorDiv
659+
expr5 = 197 * FloorDiv(i0, 197) + ModularIndexing(i0, 1, 197)
659660
simplified = sizevars.simplify_with_ranges(expr5, {})
660661
self.assertEqual(simplified, i0)
661662
self.assertEqual(expr5.subs({i0: 39485}), simplified.subs({i0: 39485}))
@@ -667,9 +668,9 @@ def test_indexing_join(self):
667668
)
668669

669670
# divisor != 1
670-
expr6 = 197 * IndexingDiv(i0, 197 * 3) + ModularIndexing(i0, 3, 197)
671+
expr6 = 197 * FloorDiv(i0, 197 * 3) + ModularIndexing(i0, 3, 197)
671672
simplified = sizevars.simplify_with_ranges(expr6, {})
672-
self.assertEqual(simplified, IndexingDiv(i0, 3))
673+
self.assertEqual(simplified, FloorDiv(i0, 3))
673674
self.assertEqual(expr6.subs({i0: 39485}), simplified.subs({i0: 39485}))
674675

675676

0 commit comments

Comments
 (0)