4848 from torch ._inductor .codegen .cpp import cexpr , CppOverrides , CppVecOverrides
4949 from torch ._inductor .codegen .triton import texpr
5050 from torch ._inductor .compile_fx import compile_fx , complex_memory_overlap
51- from torch ._inductor .ir import IndexingDiv , ModularIndexing
51+ from torch ._inductor .ir import ModularIndexing
5252 from torch ._inductor .overrides import (
5353 linear_permute_fusion ,
5454 linear_transpose ,
6060 )
6161 from torch ._inductor .sizevars import SizeVarAllocator
6262 from torch ._inductor .utils import has_torchvision_roi_align , timed
63+ from torch .fx .experimental .symbolic_shapes import FloorDiv
6364
6465 # This will only pass on pytorch builds newer than roughly 5/15/2022
6566 assert get_decompositions ([torch .ops .aten .trace ])
@@ -552,24 +553,24 @@ def test_indexing_simplification(self):
552553 self .assertEqual (
553554 sizevars .simplify_with_ranges (expr , var_ranges ), i1 + 128 * i2 + 64 * r3
554555 )
555- # if there are negative terms in ModularIndexing base, we cannot replace it with IndexingDiv
556+ # if there are negative terms in ModularIndexing base, we cannot replace it with FloorDiv
556557 expr = ModularIndexing (i1 - 15 , 1 , 64 )
557558 self .assertEqual (
558559 sizevars .simplify_with_ranges (expr , var_ranges ),
559560 ModularIndexing (i1 - 15 , 1 , 64 ),
560561 )
561562 # small terms should be kept if the rest is not guaranteed to be divisible
562563 self .assertEqual (
563- sizevars .simplify_with_ranges (IndexingDiv (r3 + i2 + i1 , 32 ), var_ranges ),
564- IndexingDiv (r3 + i2 + i1 , 32 ),
564+ sizevars .simplify_with_ranges (FloorDiv (r3 + i2 + i1 , 32 ), var_ranges ),
565+ FloorDiv (r3 + i2 + i1 , 32 ),
565566 )
566567
567568 expr = ModularIndexing (2 * i2 + r3 , 1 , 64 )
568569 # modular indexing is removed if base is smaller than modulo
569570 self .assertEqual (sizevars .simplify_with_ranges (expr , var_ranges ), 2 * i2 + r3 )
570571
571572 # check the same thing but with symbolic divisor
572- self .assertEqual (IndexingDiv (r3 * i0 , r3 ), i0 )
573+ self .assertEqual (FloorDiv (r3 * i0 , r3 ), i0 )
573574 self .assertEqual (ModularIndexing (r3 * i0 , r3 , 10 ), ModularIndexing (i0 , 1 , 10 ))
574575
575576 # (10*i) % 10 is always zero and should get optimized away
@@ -597,7 +598,7 @@ def test_indexing_simplification(self):
597598
598599 # Constant fold from divisor into base
599600 self .assertEqual (ModularIndexing (i0 * 4 , 2 , 10 ), ModularIndexing (i0 * 2 , 1 , 10 ))
600- self .assertEqual (IndexingDiv (i0 * 4 , 2 ), i0 * 2 )
601+ self .assertEqual (FloorDiv (i0 * 4 , 2 ), i0 * 2 )
601602
602603 # Nested modular indexing is correctly simplified
603604 var_ranges = {"i1" : 13 , "i2" : 121 }
@@ -607,7 +608,7 @@ def test_indexing_simplification(self):
607608 self .assertEqual (sizevars .simplify_with_ranges (expr , var_ranges ), expr )
608609 var_ranges = {"i2" : 784 }
609610 expr = ModularIndexing (ModularIndexing (i2 , 1 , 28 ), 7 , 4 )
610- expected = IndexingDiv (ModularIndexing (i2 , 1 , 28 ), 7 )
611+ expected = FloorDiv (ModularIndexing (i2 , 1 , 28 ), 7 )
611612 self .assertEqual (sizevars .simplify_with_ranges (expr , var_ranges ), expected )
612613 expr = ModularIndexing (ModularIndexing (i2 , 1 , 28 ) + 1 , 7 , 4 )
613614 self .assertEqual (sizevars .simplify_with_ranges (expr , var_ranges ), expr )
@@ -654,8 +655,8 @@ def test_indexing_join(self):
654655 ModularIndexing (i0 , 10 , i1 * i2 ) + 10 ,
655656 )
656657
657- # works for ModularIndexing + IndexingDiv
658- expr5 = 197 * IndexingDiv (i0 , 197 ) + ModularIndexing (i0 , 1 , 197 )
658+ # works for ModularIndexing + FloorDiv
659+ expr5 = 197 * FloorDiv (i0 , 197 ) + ModularIndexing (i0 , 1 , 197 )
659660 simplified = sizevars .simplify_with_ranges (expr5 , {})
660661 self .assertEqual (simplified , i0 )
661662 self .assertEqual (expr5 .subs ({i0 : 39485 }), simplified .subs ({i0 : 39485 }))
@@ -667,9 +668,9 @@ def test_indexing_join(self):
667668 )
668669
669670 # divisor != 1
670- expr6 = 197 * IndexingDiv (i0 , 197 * 3 ) + ModularIndexing (i0 , 3 , 197 )
671+ expr6 = 197 * FloorDiv (i0 , 197 * 3 ) + ModularIndexing (i0 , 3 , 197 )
671672 simplified = sizevars .simplify_with_ranges (expr6 , {})
672- self .assertEqual (simplified , IndexingDiv (i0 , 3 ))
673+ self .assertEqual (simplified , FloorDiv (i0 , 3 ))
673674 self .assertEqual (expr6 .subs ({i0 : 39485 }), simplified .subs ({i0 : 39485 }))
674675
675676
0 commit comments