
Implement new polygeist.memrefdiff operation #7431

@victor-eds

Description


Is your feature request related to a problem? Please describe

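cgeist currently implements pointer difference by converting both operands to LLVM pointers, casting them to integers (`llvm.ptrtoint`), and doing the subtraction and division on those integers. Because the relationship between the two pointers is lost at that point, some optimizations are missed.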

Describe the solution you would like

Implement a `polygeist.memrefdiff` operation to make it easier to detect cases in which further optimizations (e.g., loop unrolling) can be applied. Consider the following example:

```c
#include <stddef.h>

void foo(int *a, int *b) {
  size_t size = 1;
  for (size_t i = 0; i < size; ++i)
    a[i] = b[i];
}
```

Here we detect that the loop body runs exactly once, so we generate the following MLIR code:

```mlir
func.func @foo(%arg0: memref<?xi32>, %arg1: memref<?xi32>) attributes {llvm.linkage = #llvm.linkage<external>} {
  %0 = affine.load %arg1[0] : memref<?xi32>
  affine.store %0, %arg0[0] : memref<?xi32>
  return
}
```
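The single-iteration case can be recognized because both loop bounds and the step are compile-time constants, so the trip count can be computed statically. A minimal C sketch of that computation, assuming a positive constant step (the helper name `trip_count` is hypothetical and is not a cgeist or MLIR API):

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper: number of iterations of
 *   for (i = lb; i < ub; i += step)
 * when lb, ub and step are known constants. */
static uint64_t trip_count(uint64_t lb, uint64_t ub, uint64_t step) {
  assert(step > 0 && "step must be positive");
  if (ub <= lb)
    return 0;
  /* Ceiling division: a partial final step still executes the body. */
  return (ub - lb + step - 1) / step;
}
```

For the first `foo`, `trip_count(0, 1, 1)` is 1, which is what allows the loop to be replaced by a single copy of its body.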

If we introduce this slight change:

```c
#include <stddef.h>

void foo(int *a, int *b) {
  int *end = a + 1;
  size_t size = end - a;
  for (size_t i = 0; i < size; ++i)
    a[i] = b[i];
}
```

We no longer detect that the body is executed only once, so we generate the following instead:

```mlir
func.func @foo(%arg0: memref<?xi32>, %arg1: memref<?xi32>) attributes {llvm.linkage = #llvm.linkage<external>} {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4_i64 = arith.constant 4 : i64
  %0 = "polygeist.memref2pointer"(%arg0) : (memref<?xi32>) -> !llvm.ptr<i32>
  %1 = llvm.getelementptr %0[1] : (!llvm.ptr<i32>) -> !llvm.ptr<i32>
  %2 = llvm.ptrtoint %1 : !llvm.ptr<i32> to i64
  %3 = llvm.ptrtoint %0 : !llvm.ptr<i32> to i64
  %4 = arith.subi %2, %3 : i64
  %5 = arith.divsi %4, %c4_i64 : i64
  %6 = arith.index_cast %5 : i64 to index
  scf.for %arg2 = %c0 to %6 step %c1 {
    %7 = memref.load %arg1[%arg2] : memref<?xi32>
    memref.store %7, %arg0[%arg2] : memref<?xi32>
  }
  return
}
```
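To see what the IR above computes, here is a C rendering of the same arithmetic, as a sketch assuming pointers fit in 64 bits (the function name `lowered_ptr_diff` is hypothetical). The value is correct, but it is obtained from integer addresses, so at the memref level nothing records that `end` was built as `a + 1`, and the `scf.for` upper bound is no longer a visible constant.

```c
#include <assert.h>
#include <stdint.h>

/* Element difference computed the way the current lowering does it:
 * convert both pointers to integers, subtract, then signed-divide by
 * the element size (4 bytes for i32). */
static int64_t lowered_ptr_diff(const int32_t *end, const int32_t *begin) {
  int64_t end_addr = (int64_t)(intptr_t)end;     /* llvm.ptrtoint %1 */
  int64_t begin_addr = (int64_t)(intptr_t)begin; /* llvm.ptrtoint %0 */
  int64_t bytes = end_addr - begin_addr;         /* arith.subi */
  return bytes / (int64_t)sizeof(int32_t);       /* arith.divsi by %c4_i64 */
}
```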

This could easily be fixed by introducing the proposed operation together with a canonicalization rule such that:

```
polygeist.memrefdiff(polygeist.subindex(%ptr, %offset), %ptr) -> %offset
```
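A toy model of this canonicalization, matching on the identity of the SSA operands the way an MLIR rewrite compares values. The `Value` struct, the `OpKind` enum, and `fold_memrefdiff` are hypothetical and purely illustrative; an actual implementation would more likely be a `fold` method or rewrite pattern attached to the new operation.

```c
#include <assert.h>
#include <stddef.h>

/* Toy IR for illustration only; these are not Polygeist data structures. */
typedef enum { OP_ARG, OP_CONST, OP_SUBINDEX, OP_MEMREFDIFF } OpKind;

typedef struct Value {
  OpKind kind;
  long constant;           /* OP_CONST: the constant value */
  const struct Value *lhs; /* OP_SUBINDEX: base; OP_MEMREFDIFF: minuend */
  const struct Value *rhs; /* OP_SUBINDEX: offset; OP_MEMREFDIFF: subtrahend */
} Value;

/* memrefdiff(subindex(%ptr, %offset), %ptr) -> %offset.
 * Returns the folded value, or the op itself if the pattern does not match. */
static const Value *fold_memrefdiff(const Value *op) {
  if (op->kind != OP_MEMREFDIFF)
    return op;
  const Value *lhs = op->lhs;
  /* Same base SSA value on both sides: the difference is the offset. */
  if (lhs->kind == OP_SUBINDEX && lhs->lhs == op->rhs)
    return lhs->rhs; /* offset, measured in elements */
  return op;
}
```

In the second `foo`, `end - a` would fold to the constant offset 1, making the loop bound a constant again and recovering the single-iteration case from the first example.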

The instruction should be inserted at the point mentioned in this comment.

Lowering the new operation would produce the same code we currently generate (`llvm.ptrtoint` on both operands, `arith.subi`, then `arith.divsi` by the element size).

Of course, this would also benefit loops with more iterations, but since we do not currently perform loop unrolling, the effect is most evident with a single iteration.

Describe alternatives you have considered

Mimicking LLVM's detection of cases like the one above: this would imply working at a lower level. MLIR provides higher-level mechanisms, so we should take advantage of them.


Labels: enhancement, sycl-mlir
