diff --git a/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h b/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h new file mode 100644 index 0000000000000..0833462ea0509 --- /dev/null +++ b/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h @@ -0,0 +1,135 @@ +//===- OpenACCSupport.h - OpenACC Support Interface -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the OpenACCSupport analysis interface, which provides +// extensible support for OpenACC passes. Custom implementations +// can be registered to provide pipeline and dialect-specific information +// that cannot be adequately expressed through type or operation interfaces +// alone. +// +// Usage Pattern: +// ============== +// +// A pass that needs this functionality should call +// getAnalysis(), which will provide either: +// - A cached version if previously initialized, OR +// - A default implementation if not previously initialized +// +// This analysis is never invalidated (isInvalidated returns false), so it only +// needs to be initialized once and will persist throughout the pass pipeline. +// +// Registering a Custom Implementation: +// ===================================== +// +// If a custom implementation is needed, create a pass that runs BEFORE the pass +// that needs the analysis. In this setup pass, use +// getAnalysis() followed by setImplementation() to register +// your custom implementation. The custom implementation will need to provide +// implementation for all methods defined in the `OpenACCSupportTraits::Concept` +// class. +// +// Example: +// void MySetupPass::runOnOperation() { +// OpenACCSupport &support = getAnalysis(); +// support.setImplementation(MyCustomImpl()); +// } +// +// void MyAnalysisConsumerPass::runOnOperation() { +// OpenACCSupport &support = getAnalysis(); +// std::string name = support.getVariableName(someValue); +// // ... use the analysis results +// } +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_OPENACC_ANALYSIS_OPENACCSUPPORT_H +#define MLIR_DIALECT_OPENACC_ANALYSIS_OPENACCSUPPORT_H + +#include "mlir/IR/Value.h" +#include "mlir/Pass/AnalysisManager.h" +#include +#include + +namespace mlir { +namespace acc { + +namespace detail { +/// This class contains internal trait classes used by OpenACCSupport. +/// It follows the Concept-Model pattern used throughout MLIR (e.g., in +/// AliasAnalysis and interface definitions). +struct OpenACCSupportTraits { + class Concept { + public: + virtual ~Concept() = default; + + /// Get the variable name for a given MLIR value. + virtual std::string getVariableName(Value v) = 0; + }; + + /// This class wraps a concrete OpenACCSupport implementation and forwards + /// interface calls to it. This provides type erasure, allowing different + /// implementation types to be used interchangeably without inheritance. + template + class Model final : public Concept { + public: + explicit Model(ImplT &&impl) : impl(std::forward(impl)) {} + ~Model() override = default; + + std::string getVariableName(Value v) final { + return impl.getVariableName(v); + } + + private: + ImplT impl; + }; +}; +} // namespace detail + +//===----------------------------------------------------------------------===// +// OpenACCSupport +//===----------------------------------------------------------------------===// + +class OpenACCSupport { + using Concept = detail::OpenACCSupportTraits::Concept; + template + using Model = detail::OpenACCSupportTraits::Model; + +public: + OpenACCSupport() = default; + OpenACCSupport(Operation *op) {} + + /// Register a custom OpenACCSupport implementation. Only one implementation + /// can be registered at a time; calling this replaces any existing + /// implementation. + template + void setImplementation(AnalysisT &&analysis) { + impl = + std::make_unique>(std::forward(analysis)); + } + + /// Get the variable name for a given value. + /// + /// \param v The MLIR value to get the variable name for. + /// \return The variable name, or an empty string if unavailable. + std::string getVariableName(Value v); + + /// Signal that this analysis should always be preserved so that + /// underlying implementation registration is not lost. + bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) { + return false; + } + +private: + /// The registered custom implementation (if any). + std::unique_ptr impl; +}; + +} // namespace acc +} // namespace mlir + +#endif // MLIR_DIALECT_OPENACC_ANALYSIS_OPENACCSUPPORT_H diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h index 378f4348f2cf1..0ee88c6f47b67 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h @@ -38,6 +38,11 @@ std::optional getDefaultAttr(mlir::Operation *op); /// Get the type category of an OpenACC variable. mlir::acc::VariableTypeCategory getTypeCategory(mlir::Value var); +/// Attempts to extract the variable name from a value by walking through +/// view-like operations until an `acc.var_name` attribute is found. Returns +/// empty string if no name is found. +std::string getVariableName(mlir::Value v); + } // namespace acc } // namespace mlir diff --git a/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt new file mode 100644 index 0000000000000..f305068e1b3bc --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library(MLIROpenACCAnalysis + OpenACCSupport.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC + + LINK_LIBS PUBLIC + MLIRIR + MLIROpenACCDialect + MLIROpenACCUtils + MLIRSupport +) + diff --git a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp new file mode 100644 index 0000000000000..f6b4534794eaf --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp @@ -0,0 +1,26 @@ +//===- OpenACCSupport.cpp - OpenACCSupport Implementation -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the OpenACCSupport analysis interface. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h" +#include "mlir/Dialect/OpenACC/OpenACCUtils.h" + +namespace mlir { +namespace acc { + +std::string OpenACCSupport::getVariableName(Value v) { + if (impl) + return impl->getVariableName(v); + return acc::getVariableName(v); +} + +} // namespace acc +} // namespace mlir diff --git a/mlir/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/CMakeLists.txt index 7117520599fa6..e8a916e824d71 100644 --- a/mlir/lib/Dialect/OpenACC/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenACC/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(Analysis) add_subdirectory(IR) add_subdirectory(Utils) add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp index 12233254f3fb4..89adda82646e6 100644 --- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp +++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/OpenACC/OpenACCUtils.h" #include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/TypeSwitch.h" mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region ®ion) { @@ -78,3 +79,30 @@ mlir::acc::VariableTypeCategory mlir::acc::getTypeCategory(mlir::Value var) { pointerLikeTy.getElementType()); return typeCategory; } + +std::string mlir::acc::getVariableName(mlir::Value v) { + Value current = v; + + // Walk through view operations until a name is found or can't go further + while (Operation *definingOp = current.getDefiningOp()) { + // Check for `acc.var_name` attribute + if (auto varNameAttr = + definingOp->getAttrOfType(getVarNameAttrName())) + return varNameAttr.getName().str(); + + // If it is a data entry operation, get name via getVarName + if (isa(definingOp)) + if (auto name = acc::getVarName(definingOp)) + return name->str(); + + // If it's a view operation, continue to the source + if (auto viewOp = dyn_cast(definingOp)) { + current = viewOp.getViewSource(); + continue; + } + + break; + } + + return ""; +} diff --git a/mlir/test/Dialect/OpenACC/support-analysis-varname.mlir b/mlir/test/Dialect/OpenACC/support-analysis-varname.mlir new file mode 100644 index 0000000000000..af52befb156e4 --- /dev/null +++ b/mlir/test/Dialect/OpenACC/support-analysis-varname.mlir @@ -0,0 +1,88 @@ +// RUN: mlir-opt %s -split-input-file -test-acc-support | FileCheck %s + +// Test with direct variable names +func.func @test_direct_var_name() { + // Create a memref with acc.var_name attribute + %0 = memref.alloca() {acc.var_name = #acc.var_name<"my_variable">} : memref<10xi32> + + %1 = memref.cast %0 {test.var_name} : memref<10xi32> to memref<10xi32> + + // CHECK: op=%{{.*}} = memref.cast %{{.*}} {test.var_name} : memref<10xi32> to memref<10xi32> + // CHECK-NEXT: getVariableName="my_variable" + + return +} + +// ----- + +// Test through memref.cast +func.func @test_through_cast() { + // Create a 5x2 memref with acc.var_name attribute + %0 = memref.alloca() {acc.var_name = #acc.var_name<"casted_variable">} : memref<5x2xi32> + + // Cast to dynamic dimensions + %1 = memref.cast %0 : memref<5x2xi32> to memref + + // Mark with test attribute - should find name through cast + %2 = memref.cast %1 {test.var_name} : memref to memref<5x2xi32> + + // CHECK: op=%{{.*}} = memref.cast %{{.*}} {test.var_name} : memref to memref<5x2xi32> + // CHECK-NEXT: getVariableName="casted_variable" + + return +} + +// ----- + +// Test with no variable name +func.func @test_no_var_name() { + // Create a memref without acc.var_name attribute + %0 = memref.alloca() : memref<10xi32> + + // Mark with test attribute - should find empty string + %1 = memref.cast %0 {test.var_name} : memref<10xi32> to memref<10xi32> + + // CHECK: op=%{{.*}} = memref.cast %{{.*}} {test.var_name} : memref<10xi32> to memref<10xi32> + // CHECK-NEXT: getVariableName="" + + return +} + +// ----- + +// Test through multiple casts +func.func @test_multiple_casts() { + // Create a memref with acc.var_name attribute + %0 = memref.alloca() {acc.var_name = #acc.var_name<"multi_cast">} : memref<10xi32> + + // Multiple casts + %1 = memref.cast %0 : memref<10xi32> to memref + %2 = memref.cast %1 : memref to memref<10xi32> + + // Mark with test attribute - should find name through multiple casts + %3 = memref.cast %2 {test.var_name} : memref<10xi32> to memref<10xi32> + + // CHECK: op=%{{.*}} = memref.cast %{{.*}} {test.var_name} : memref<10xi32> to memref<10xi32> + // CHECK-NEXT: getVariableName="multi_cast" + + return +} + +// ----- + +// Test with acc.copyin operation +func.func @test_copyin_name() { + // Create a memref + %0 = memref.alloca() : memref<10xf32> + + // Create an acc.copyin operation with a name + %1 = acc.copyin varPtr(%0 : memref<10xf32>) -> memref<10xf32> {name = "input_data"} + + // Mark with test attribute - should find name from copyin operation + %2 = memref.cast %1 {test.var_name} : memref<10xf32> to memref + + // CHECK: op=%{{.*}} = memref.cast %{{.*}} {test.var_name} : memref<10xf32> to memref + // CHECK-NEXT: getVariableName="input_data" + + return +} diff --git a/mlir/test/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/test/lib/Dialect/OpenACC/CMakeLists.txt index 1e593389ec683..a54b642d4db42 100644 --- a/mlir/test/lib/Dialect/OpenACC/CMakeLists.txt +++ b/mlir/test/lib/Dialect/OpenACC/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_library(MLIROpenACCTestPasses TestOpenACC.cpp TestPointerLikeTypeInterface.cpp TestRecipePopulate.cpp + TestOpenACCSupport.cpp EXCLUDE_FROM_LIBMLIR ) @@ -11,6 +12,7 @@ mlir_target_link_libraries(MLIROpenACCTestPasses PUBLIC MLIRFuncDialect MLIRMemRefDialect MLIROpenACCDialect + MLIROpenACCAnalysis MLIRPass MLIRSupport ) diff --git a/mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp b/mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp index bea21b9827f7e..e59d77777ee40 100644 --- a/mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp +++ b/mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp @@ -16,11 +16,13 @@ namespace test { // Forward declarations of individual test pass registration functions void registerTestPointerLikeTypeInterfacePass(); void registerTestRecipePopulatePass(); +void registerTestOpenACCSupportPass(); // Unified registration function for all OpenACC tests void registerTestOpenACC() { registerTestPointerLikeTypeInterfacePass(); registerTestRecipePopulatePass(); + registerTestOpenACCSupportPass(); } } // namespace test diff --git a/mlir/test/lib/Dialect/OpenACC/TestOpenACCSupport.cpp b/mlir/test/lib/Dialect/OpenACC/TestOpenACCSupport.cpp new file mode 100644 index 0000000000000..8bf984bdc2632 --- /dev/null +++ b/mlir/test/lib/Dialect/OpenACC/TestOpenACCSupport.cpp @@ -0,0 +1,73 @@ +//===- TestOpenACCSupport.cpp - Test OpenACCSupport Analysis -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains test passes for testing the OpenACCSupport analysis. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using namespace mlir::acc; + +namespace { + +struct TestOpenACCSupportPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOpenACCSupportPass) + + StringRef getArgument() const override { return "test-acc-support"; } + + StringRef getDescription() const override { + return "Test OpenACCSupport analysis"; + } + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } +}; + +void TestOpenACCSupportPass::runOnOperation() { + auto func = getOperation(); + + // Get the OpenACCSupport analysis + OpenACCSupport &support = getAnalysis(); + + // Walk through operations looking for test attributes + func.walk([&](Operation *op) { + // Check for test.var_name attribute. This is the marker used to identify + // the operations that need to be tested for getVariableName. + if (op->hasAttr("test.var_name")) { + // For each result of this operation, try to get the variable name + for (auto result : op->getResults()) { + std::string foundName = support.getVariableName(result); + llvm::outs() << "op=" << *op << "\n\tgetVariableName=\"" << foundName + << "\"\n"; + } + } + }); +} + +} // namespace + +namespace mlir { +namespace test { + +void registerTestOpenACCSupportPass() { + PassRegistration(); +} + +} // namespace test +} // namespace mlir diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp index ab817b640edb3..3fbbcc90a67c9 100644 --- a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp +++ b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp @@ -410,3 +410,78 @@ TEST_F(OpenACCUtilsTest, getTypeCategoryArray) { VariableTypeCategory category = getTypeCategory(varPtr); EXPECT_EQ(category, VariableTypeCategory::array); } + +//===----------------------------------------------------------------------===// +// getVariableName Tests +//===----------------------------------------------------------------------===// + +TEST_F(OpenACCUtilsTest, getVariableNameDirect) { + // Create a memref with acc.var_name attribute + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + + // Set the acc.var_name attribute + auto varNameAttr = VarNameAttr::get(&context, "my_variable"); + allocOp.get()->setAttr(getVarNameAttrName(), varNameAttr); + + Value varPtr = allocOp->getResult(); + + // Test that getVariableName returns the variable name + std::string varName = getVariableName(varPtr); + EXPECT_EQ(varName, "my_variable"); +} + +TEST_F(OpenACCUtilsTest, getVariableNameThroughCast) { + // Create a 5x2 memref with acc.var_name attribute + auto memrefTy = MemRefType::get({5, 2}, b.getI32Type()); + OwningOpRef allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + + // Set the acc.var_name attribute on the alloca + auto varNameAttr = VarNameAttr::get(&context, "casted_variable"); + allocOp.get()->setAttr(getVarNameAttrName(), varNameAttr); + + Value allocResult = allocOp->getResult(); + + // Create a memref.cast operation to a flattened 10-element array + auto castedMemrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef castOp = + memref::CastOp::create(b, loc, castedMemrefTy, allocResult); + + Value castedPtr = castOp->getResult(); + + // Test that getVariableName walks through the cast to find the variable name + std::string varName = getVariableName(castedPtr); + EXPECT_EQ(varName, "casted_variable"); +} + +TEST_F(OpenACCUtilsTest, getVariableNameNotFound) { + // Create a memref without acc.var_name attribute + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + + Value varPtr = allocOp->getResult(); + + // Test that getVariableName returns empty string when no name is found + std::string varName = getVariableName(varPtr); + EXPECT_EQ(varName, ""); +} + +TEST_F(OpenACCUtilsTest, getVariableNameFromCopyin) { + // Create a memref + auto memrefTy = MemRefType::get({10}, b.getI32Type()); + OwningOpRef allocOp = + memref::AllocaOp::create(b, loc, memrefTy); + + Value varPtr = allocOp->getResult(); + StringRef name = "data_array"; + OwningOpRef copyinOp = + CopyinOp::create(b, loc, varPtr, /*structured=*/true, /*implicit=*/true, + /*name=*/name); + + // Test that getVariableName extracts the name from the copyin operation + std::string varName = getVariableName(copyinOp->getAccVar()); + EXPECT_EQ(varName, name); +}