1+ # RUN: %PYTHON %s | FileCheck %s
2+
3+ from mlir .ir import *
4+ from mlir .dialects import shard
5+ from mlir .dialects import func
6+
7+
8+ def constructAndPrintInModule (f ):
9+ print ("\n TEST:" , f .__name__ )
10+ with Context (), Location .unknown ():
11+ module = Module .create ()
12+ with InsertionPoint (module .body ):
13+ f ()
14+ print (module )
15+ return f
16+
17+
18+ # CHECK-LABEL: TEST: testShardGrid
19+ @constructAndPrintInModule
20+ def testShardGrid ():
21+ # Test creating shard grids with different shapes
22+ grid2d = shard .GridOp ("grid_2d" , [2 , 2 ])
23+ grid1d = shard .GridOp ("grid_1d" , [4 ])
24+ grid_dynamic = shard .GridOp ("grid_dynamic" , [2 , - 1 ]) # -1 for dynamic dimension
25+
26+ # CHECK: shard.grid @grid_2d(shape = 2x2)
27+ # CHECK: shard.grid @grid_1d(shape = 4)
28+ # CHECK: shard.grid @grid_dynamic(shape = 2x?)
29+
30+
31+ # CHECK-LABEL: TEST: testCollectiveOperations
32+ @constructAndPrintInModule
33+ def testCollectiveOperations ():
34+ # Create grid and types
35+ grid = shard .GridOp ("grid_2x2" , [2 , 2 ])
36+ i32 = IntegerType .get_signless (32 )
37+ input_type = RankedTensorType .get ([4 , 2 ], i32 )
38+ gather_result_type = RankedTensorType .get ([4 , 4 ], i32 )
39+
40+ # Create a function to hold the operations
41+ func_type = FunctionType .get ([input_type ], [input_type ])
42+ test_func = func .FuncOp ("test_collectives" , func_type )
43+
44+ with InsertionPoint (test_func .add_entry_block ()):
45+ arg = test_func .entry_block .arguments [0 ]
46+
47+ gather_op = shard .AllGatherOp (
48+ input = arg ,
49+ grid = FlatSymbolRefAttr .get ("grid_2x2" ),
50+ grid_axes = ArrayAttr .get ([IntegerAttr .get (i32 , 1 )]),
51+ gather_axis = IntegerAttr .get (i32 , 1 ),
52+ result = gather_result_type ,
53+ )
54+
55+ reduce_op = shard .AllReduceOp (
56+ input = arg ,
57+ grid = FlatSymbolRefAttr .get ("grid_2x2" ),
58+ reduction = shard .ReductionKind .Sum ,
59+ result = input_type ,
60+ )
61+
62+ func .ReturnOp ([reduce_op ])
63+
64+ # CHECK: shard.grid @grid_2x2(shape = 2x2)
65+ # CHECK: func @test_collectives(%{{.*}}: tensor<4x2xi32>) -> tensor<4x2xi32>
66+ # CHECK: %{{.*}} = shard.all_gather %{{.*}} on @grid_2x2 grid_axes = [1] gather_axis = 1 : tensor<4x2xi32> -> tensor<4x4xi32>
67+ # CHECK: %{{.*}} = shard.all_reduce %{{.*}} on @grid_2x2 reduction = sum : tensor<4x2xi32> -> tensor<4x2xi32>
0 commit comments