diff --git a/CMakeLists.txt b/CMakeLists.txt
index e2008ebfa..74800e934 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -41,20 +41,21 @@ if(FLAGTREE_PLUGIN)
   add_definitions(-D__FLAGTREE_PLUGIN__)
 endif()
 
+# FLAGTREE SPEC LIB GET FUNC
+function(get_flagtree_backend_lib lib_name output_lib)
+  set(ret FlagTree_${FLAGTREE_BACKEND}_${lib_name})
+  if(NOT TARGET ${ret})
+    set(ret "")
+  endif()
+  set(${output_lib} ${ret} PARENT_SCOPE)
+endfunction()
+
 project(triton)
 include(CTest)
 
 if (FLAGTREE_BACKEND STREQUAL "ascend")
   set(TRITON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
   set(PATCHED_TRITON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/ascend/triton_patch")
-  set(PATCHED_TRITON_LIBRARIES
-    "TritonIR"
-  )
-  set(PATCHED_TRITON_DEPENDS
-    "TritonTableGen"
-  )
-  include_directories(${PATCHED_TRITON_ROOT_DIR}/include)
-  include_directories(${PROJECT_BINARY_DIR}/third_party/ascend/triton_patch/include) # Tablegen'd files
 endif()
 
 if(NOT WIN32)
@@ -99,6 +100,12 @@ if(TRITON_BUILD_UT)
 endif()
 
 # Compiler flags
+## flagtree spec include dir
+set(BACKEND_SPEC_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/third_party/${FLAGTREE_BACKEND}/backend/flagtree_backend_specialization/include)
+if(FLAGTREE_BACKEND AND EXISTS ${BACKEND_SPEC_INCLUDE_DIR})
+  include_directories(${BACKEND_SPEC_INCLUDE_DIR})
+endif()
+## flagtree third_party include dir
 set(BACKEND_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/${FLAGTREE_BACKEND}/include)
 if(FLAGTREE_BACKEND AND EXISTS "${BACKEND_INCLUDE_DIR}")
   include_directories(${BACKEND_INCLUDE_DIR})
@@ -134,54 +141,17 @@ function(add_triton_object name)
     INTERFACE $<TARGET_OBJECTS:${name}>
   )
 
-  if (FLAGTREE_BACKEND STREQUAL "ascend")
-    set(patched_depends "")
-    foreach(dep ${ARG_DEPENDS})
-      list(FIND PATCHED_TRITON_DEPENDS "${dep}" index)
-      if(index GREATER_EQUAL 0)
-        list(APPEND patched_depends "Patched_${dep}")
-        message(STATUS "Replace ${dep} by Patched_${dep} as a dependent of ${name}")
-      else()
-        list(APPEND patched_depends ${dep})
-      endif()
-    endforeach()
-    if(patched_depends)
-      add_dependencies(${name} ${patched_depends})
-    endif()
-
-    set(patched_link_libs "")
-    foreach(lib ${ARG_LINK_LIBS})
-      list(FIND PATCHED_TRITON_LIBRARIES "${lib}" index)
-      if(index GREATER_EQUAL 0)
-        list(APPEND patched_link_libs "Patched_${lib}")
-        message(STATUS "Replace ${lib} by Patched_${lib} to be linked by ${name}")
-      else()
-        list(APPEND patched_link_libs ${lib})
-      endif()
-    endforeach()
-    if(patched_link_libs)
-      target_link_libraries(${name} PUBLIC ${patched_link_libs})
-    endif()
-  else()
-    #add_library(${name} OBJECT ${ARG_UNPARSED_ARGUMENTS})
-    if(ARG_DEPENDS)
-      add_dependencies(${name} ${ARG_DEPENDS})
-    endif()
-    if(ARG_LINK_LIBS)
-      target_link_libraries(${name} PUBLIC ${ARG_LINK_LIBS})
-    endif()
+  #add_library(${name} OBJECT ${ARG_UNPARSED_ARGUMENTS})
+  if(ARG_DEPENDS)
+    add_dependencies(${name} ${ARG_DEPENDS})
+  endif()
+  if(ARG_LINK_LIBS)
+    target_link_libraries(${name} PUBLIC ${ARG_LINK_LIBS})
   endif()
 endfunction(add_triton_object)
 
 set_property(GLOBAL PROPERTY TRITON_LIBS "")
 function(add_triton_library name)
-  if (FLAGTREE_BACKEND STREQUAL "ascend")
-    list(FIND PATCHED_TRITON_LIBRARIES "${name}" index)
-    if(index GREATER_EQUAL 0)
-      message(STATUS "Adding Patched_${name} as a lib, instead of ${name}")
-      return()
-    endif()
-  endif()
   set_property(GLOBAL APPEND PROPERTY TRITON_LIBS ${name})
   add_triton_object(${name} ${ARGN})
   llvm_update_compile_flags(${name})
@@ -230,11 +200,6 @@ elseif(NOT FLAGTREE_BACKEND)
   add_subdirectory(lib)
 endif()
 
-if (FLAGTREE_BACKEND STREQUAL "ascend")
-  add_subdirectory(${PATCHED_TRITON_ROOT_DIR}/include)
-  add_subdirectory(${PATCHED_TRITON_ROOT_DIR}/lib)
-endif()
-
 # find_package(PythonLibs REQUIRED)
 set(TRITON_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
 set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
diff --git a/include/triton/Dialect/Triton/IR/CMakeLists.txt b/include/triton/Dialect/Triton/IR/CMakeLists.txt
index f682f54a1..91104b6f1 100644
--- a/include/triton/Dialect/Triton/IR/CMakeLists.txt
+++ b/include/triton/Dialect/Triton/IR/CMakeLists.txt
@@ -1,6 +1,10 @@
 set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})
 
 set(LLVM_TARGET_DEFINITIONS TritonOps.td)
+set(BACKEND_SPEC_TD ${BACKEND_SPEC_INCLUDE_DIR}/triton/Dialect/Triton/IR/TritonOps.td)
+if(EXISTS ${BACKEND_SPEC_TD})
+  set(LLVM_TARGET_DEFINITIONS ${BACKEND_SPEC_TD})
+endif()
 mlir_tablegen(Ops.h.inc -gen-op-decls)
 mlir_tablegen(Ops.cpp.inc -gen-op-defs)
 mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
@@ -13,6 +17,10 @@ mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
 add_mlir_doc(TritonDialect TritonDialect dialects/ -gen-dialect-doc)
 
 set(LLVM_TARGET_DEFINITIONS TritonTypes.td)
+set(BACKEND_SPEC_TD ${BACKEND_SPEC_INCLUDE_DIR}/triton/Dialect/Triton/IR/TritonTypes.td)
+if(EXISTS ${BACKEND_SPEC_TD})
+  set(LLVM_TARGET_DEFINITIONS ${BACKEND_SPEC_TD})
+endif()
 mlir_tablegen(Types.h.inc -gen-typedef-decls)
 mlir_tablegen(Types.cpp.inc -gen-typedef-defs)
 
@@ -24,4 +32,11 @@ set(LLVM_TARGET_DEFINITIONS TritonTypeInterfaces.td)
 mlir_tablegen(TritonTypeInterfaces.h.inc -gen-type-interface-decls)
 mlir_tablegen(TritonTypeInterfaces.cpp.inc -gen-type-interface-defs)
 
+set(BACKEND_SPEC_TD ${BACKEND_SPEC_INCLUDE_DIR}/triton/Dialect/Triton/IR/TritonOpInterfaces.td)
+if(EXISTS ${BACKEND_SPEC_TD})
+  set(LLVM_TARGET_DEFINITIONS ${BACKEND_SPEC_TD})
+  mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls)
+  mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs)
+endif()
+
 add_public_tablegen_target(TritonTableGen)
diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h
index b1f1597c5..9dfa225f8 100644
--- a/include/triton/Dialect/Triton/IR/Dialect.h
+++ b/include/triton/Dialect/Triton/IR/Dialect.h
@@ -17,6 +17,11 @@
 #include "triton/Dialect/Triton/IR/Traits.h"
 #include "triton/Dialect/Triton/IR/Types.h"
 
+#include "flagtree_spec.h"
+#ifdef FLAGTREE_SPEC_Dialect_Triton_IR_OpInterfaces_head
+#include "triton/Dialect/Triton/IR/OpInterfaces.h"
+#endif
+
 #define GET_OP_CLASSES
 #include "triton/Dialect/Triton/IR/Ops.h.inc"
 
@@ -24,57 +29,57 @@ namespace mlir {
 namespace triton {
 
 struct GlobalMemory : public SideEffects::Resource::Base<GlobalMemory> {
-  StringRef getName() final { return "<GlobalMemory>"; }
+  StringRef getName() final { return "<GlobalMemory>"; }
 };
 
 class DialectInferLayoutInterface
     : public DialectInterface::Base<DialectInferLayoutInterface> {
 public:
-  DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {}
-
-  virtual LogicalResult
-  inferTransOpEncoding(Attribute operandEncoding, ArrayRef<int32_t> order,
-                       Attribute &resultEncoding) const = 0;
-
-  virtual LogicalResult
-  inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
-                        Attribute &resultEncoding) const = 0;
-
-  virtual LogicalResult
-  inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
-                            Attribute &resultEncoding,
-                            std::optional<Location> location) const = 0;
-
-  // Note: This function only verifies the operand encoding. It doesn't infer
-  // the result encoding.
-  virtual LogicalResult
-  inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx,
-                     Attribute retEncoding,
-                     std::optional<Location> location) const = 0;
-
-  // Tries to compute the encoding for the result of a reshape operation that
-  // makes the reshape a "nop", i.e. the same GPU threads contain the same
-  // elements as before the reshape. Note that this is not always possible (in
-  // which case you'd need to choose a different layout for the input to the
-  // reshape).
-  virtual LogicalResult
-  inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
-                                  ArrayRef<int64_t> dstShape, Attribute &dstEnc,
-                                  std::optional<Location> loc) const = 0;
-
-  virtual LogicalResult
-  inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
-                      std::optional<Location> loc) const = 0;
-
-  virtual LogicalResult
-  inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc,
-                       std::optional<Location> loc) const = 0;
-
-  // Verify that the encoding are compatible to be used together in a dot
-  // operation
-  virtual LogicalResult
-  verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA,
-                                   Attribute operandEncodingB) const = 0;
+  DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {}
+
+  virtual LogicalResult
+  inferTransOpEncoding(Attribute operandEncoding, ArrayRef<int32_t> order,
+                       Attribute &resultEncoding) const = 0;
+
+  virtual LogicalResult
+  inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
+                        Attribute &resultEncoding) const = 0;
+
+  virtual LogicalResult
+  inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
+                            Attribute &resultEncoding,
+                            std::optional<Location> location) const = 0;
+
+  // Note: This function only verifies the operand encoding. It doesn't infer
+  // the result encoding.
+  virtual LogicalResult
+  inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx,
+                     Attribute retEncoding,
+                     std::optional<Location> location) const = 0;
+
+  // Tries to compute the encoding for the result of a reshape operation that
+  // makes the reshape a "nop", i.e. the same GPU threads contain the same
+  // elements as before the reshape. Note that this is not always possible (in
+  // which case you'd need to choose a different layout for the input to the
+  // reshape).
+  virtual LogicalResult
+  inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
+                                  ArrayRef<int64_t> dstShape, Attribute &dstEnc,
+                                  std::optional<Location> loc) const = 0;
+
+  virtual LogicalResult
+  inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
+                      std::optional<Location> loc) const = 0;
+
+  virtual LogicalResult
+  inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc,
+                       std::optional<Location> loc) const = 0;
+
+  // Verify that the encoding are compatible to be used together in a dot
+  // operation
+  virtual LogicalResult
+  verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA,
+                                   Attribute operandEncodingB) const = 0;
 };
 
 } // namespace triton
diff --git a/lib/Dialect/Triton/IR/CMakeLists.txt b/lib/Dialect/Triton/IR/CMakeLists.txt
index 752daa7ff..9cd3ee6c1 100644
--- a/lib/Dialect/Triton/IR/CMakeLists.txt
+++ b/lib/Dialect/Triton/IR/CMakeLists.txt
@@ -1,3 +1,8 @@
+set(_EXTRA_LINK_LIBS FlagTree_${FLAGTREE_BACKEND}_TritonDialectTritonIR)
+if(NOT TARGET ${_EXTRA_LINK_LIBS})
+  set(_EXTRA_LINK_LIBS "")
+endif()
+
 add_triton_library(TritonIR
   Dialect.cpp
   Ops.cpp
@@ -12,4 +17,6 @@ add_triton_library(TritonIR
   MLIRArithDialect
   MLIRMathDialect
   MLIRSCFDialect
+  FlagTree_${FLAGTREE_BACKEND}_TritonDialectTritonIR
+  # TODO: linking via ${_EXTRA_LINK_LIBS} currently fails, so the target name is spelled out here
 )
diff --git a/lib/Dialect/Triton/IR/Dialect.cpp b/lib/Dialect/Triton/IR/Dialect.cpp
index dc2417712..9cd053cf5 100644
--- a/lib/Dialect/Triton/IR/Dialect.cpp
+++ b/lib/Dialect/Triton/IR/Dialect.cpp
@@ -15,6 +15,10 @@
 #include "triton/Dialect/Triton/IR/Dialect.cpp.inc"
 #include "triton/Dialect/Triton/IR/TritonTypeInterfaces.cpp.inc"
 
+#ifdef FLAGTREE_SPEC_Dialect_Triton_IR_OpInterfaces_inc
+#include "triton/Dialect/Triton/IR/OpInterfaces.cpp.inc"
+#endif
+
 using namespace mlir;
 using namespace mlir::triton;
 
diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp
index ffea5f3c6..83d7b393e 100644
--- a/lib/Dialect/Triton/IR/Ops.cpp
+++ b/lib/Dialect/Triton/IR/Ops.cpp
@@ -9,6 +9,8 @@
 #include "triton/Dialect/Triton/IR/Types.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 
+#include "flagtree_spec.h"
+
 namespace mlir {
 namespace triton {
 
@@ -818,6 +820,7 @@ OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) {
 // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Func/IR/FuncOps.cpp
 // We could revert it back once MLIR has a better inliner interface.
 //-- FuncOp --
+#ifndef FLAGTREE_SPEC_Dialect_Triton_IR_Ops_build
 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                    FunctionType type, ArrayRef<NamedAttribute> attrs,
                    ArrayRef<DictionaryAttr> argAttrs) {
@@ -834,6 +837,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
       builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
       getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 }
+#endif
 
 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
   auto buildFuncType =
@@ -912,6 +916,7 @@ LogicalResult ReturnOp::verify() {
 }
 
 // -- JoinOp --
+#ifndef FLAGTREE_SPEC_Dialect_Triton_IR_Ops_inferReturnTypes
 LogicalResult
 JoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
                          ValueRange operands, DictionaryAttr attributes,
@@ -943,6 +948,7 @@ JoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
       RankedTensorType::get(retShape, srcTy.getElementType(), retEnc));
   return success();
 }
+#endif
 
 // -- SplitOp --
 LogicalResult SplitOp::inferReturnTypes(
diff --git a/lib/Dialect/Triton/IR/Traits.cpp b/lib/Dialect/Triton/IR/Traits.cpp
index 19729aee5..88dc43270 100644
--- a/lib/Dialect/Triton/IR/Traits.cpp
+++ b/lib/Dialect/Triton/IR/Traits.cpp
@@ -8,6 +8,8 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "llvm/Support/ErrorHandling.h"
 
+#include "flagtree_spec.h"
+
 using namespace mlir;
 namespace ttg = mlir::triton::gpu;
 
@@ -64,6 +66,7 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultEncoding(
   return verifySameOperandsEncoding(op, allowTensorPointerType);
 }
 
+#ifndef FLAGTREE_SPEC_Dialect_Triton_IR_TRAITS_verifyTensorSize
 LogicalResult OpTrait::impl::verifyTensorSize(Operation *op) {
   for (auto opType : op->getOperandTypes()) {
     if (auto tensorType = dyn_cast<RankedTensorType>(opType)) {
@@ -97,6 +100,7 @@ LogicalResult OpTrait::impl::verifyTensorSize(Operation *op) {
   }
   return success();
 }
+#endif
 
 // Check that the Triton layouts on op's operands and return types are valid.
 // For example, we check that the number of warps per block in a Triton GPU
diff --git a/third_party/ascend/CMakeLists.txt b/third_party/ascend/CMakeLists.txt
index 8b77a38ed..2ffef563c 100644
--- a/third_party/ascend/CMakeLists.txt
+++ b/third_party/ascend/CMakeLists.txt
@@ -1,6 +1,8 @@
 #add_subdirectory(triton-adapter triton-adapter)
 #add_subdirectory(test)
+add_subdirectory(backend/flagtree_backend_specialization/lib)
+
 add_triton_plugin(TritonAscend ${CMAKE_CURRENT_SOURCE_DIR}/triton_ascend.cpp)
 
 target_include_directories(TritonAscend PRIVATE ${CMAKE_SOURCE_DIR}/third_party/flir/include)
diff --git a/third_party/ascend/backend/flagtree_backend_specialization/include/flagtree_spec.h b/third_party/ascend/backend/flagtree_backend_specialization/include/flagtree_spec.h
new file mode 100644
index 000000000..e84620633
--- /dev/null
+++ b/third_party/ascend/backend/flagtree_backend_specialization/include/flagtree_spec.h
@@ -0,0 +1,8 @@
+#ifndef ASCEND_FLAGTREE_SPEC_H_
+#define ASCEND_FLAGTREE_SPEC_H_
+
+#include "triton/Dialect/Triton/IR/ascend_Dialect.h"
+#include "triton/Dialect/Triton/IR/ascend_Ops.h"
+#include "triton/Dialect/Triton/IR/ascend_Traits.h"
+
+#endif // ASCEND_FLAGTREE_SPEC_H_
diff --git a/third_party/ascend/triton_patch/include/runtime/libentry/libentry.h b/third_party/ascend/backend/flagtree_backend_specialization/include/runtime/libentry/libentry.h
similarity index 100%
rename from third_party/ascend/triton_patch/include/runtime/libentry/libentry.h
rename to third_party/ascend/backend/flagtree_backend_specialization/include/runtime/libentry/libentry.h
diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/OpInterfaces.h b/third_party/ascend/backend/flagtree_backend_specialization/include/triton/Dialect/Triton/IR/OpInterfaces.h
similarity index 100%
rename from third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/OpInterfaces.h
rename to third_party/ascend/backend/flagtree_backend_specialization/include/triton/Dialect/Triton/IR/OpInterfaces.h
diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonAttrDefs.td b/third_party/ascend/backend/flagtree_backend_specialization/include/triton/Dialect/Triton/IR/TritonAttrDefs.td
similarity index 100%
rename from third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonAttrDefs.td
rename to third_party/ascend/backend/flagtree_backend_specialization/include/triton/Dialect/Triton/IR/TritonAttrDefs.td
diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td b/third_party/ascend/backend/flagtree_backend_specialization/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td
similarity index 100%
rename from third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td
rename to third_party/ascend/backend/flagtree_backend_specialization/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td
diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonOps.td b/third_party/ascend/backend/flagtree_backend_specialization/include/triton/Dialect/Triton/IR/TritonOps.td
similarity index 100%
rename from third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonOps.td
rename to third_party/ascend/backend/flagtree_backend_specialization/include/triton/Dialect/Triton/IR/TritonOps.td
diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonTypes.td b/third_party/ascend/backend/flagtree_backend_specialization/include/triton/Dialect/Triton/IR/TritonTypes.td
similarity index 100%
rename from third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/TritonTypes.td
rename to third_party/ascend/backend/flagtree_backend_specialization/include/triton/Dialect/Triton/IR/TritonTypes.td
diff --git a/third_party/ascend/backend/flagtree_backend_specialization/include/triton/Dialect/Triton/IR/ascend_Dialect.h b/third_party/ascend/backend/flagtree_backend_specialization/include/triton/Dialect/Triton/IR/ascend_Dialect.h
new file mode 100644
index 000000000..df869de83
--- /dev/null
+++ b/third_party/ascend/backend/flagtree_backend_specialization/include/triton/Dialect/Triton/IR/ascend_Dialect.h
@@ -0,0 +1,8 @@
+// TODO: When upgrading to Triton 3.4.0, remove this file and use the upstream Triton file.
+#ifndef ASCEND_TRITON_DIALECT_TRITON_IR_DIALECT_H_
+#define ASCEND_TRITON_DIALECT_TRITON_IR_DIALECT_H_
+
+#define FLAGTREE_SPEC_Dialect_Triton_IR_OpInterfaces_head
+#define FLAGTREE_SPEC_Dialect_Triton_IR_OpInterfaces_inc
+
+#endif // ASCEND_TRITON_DIALECT_TRITON_IR_DIALECT_H_
diff --git a/third_party/ascend/backend/flagtree_backend_specialization/include/triton/Dialect/Triton/IR/ascend_Ops.h b/third_party/ascend/backend/flagtree_backend_specialization/include/triton/Dialect/Triton/IR/ascend_Ops.h
new file mode 100644
index 000000000..b88f3d565
--- /dev/null
+++ b/third_party/ascend/backend/flagtree_backend_specialization/include/triton/Dialect/Triton/IR/ascend_Ops.h
@@ -0,0 +1,7 @@
+#ifndef ASCEND_TRITON_DIALECT_TRITON_IR_OPS_H_
+#define ASCEND_TRITON_DIALECT_TRITON_IR_OPS_H_
+
+#define FLAGTREE_SPEC_Dialect_Triton_IR_Ops_build
+#define FLAGTREE_SPEC_Dialect_Triton_IR_Ops_inferReturnTypes
+
+#endif // ASCEND_TRITON_DIALECT_TRITON_IR_OPS_H_
diff --git a/third_party/ascend/backend/flagtree_backend_specialization/include/triton/Dialect/Triton/IR/ascend_Traits.h b/third_party/ascend/backend/flagtree_backend_specialization/include/triton/Dialect/Triton/IR/ascend_Traits.h
new file mode 100644
index 000000000..fad138a15
--- /dev/null
+++ b/third_party/ascend/backend/flagtree_backend_specialization/include/triton/Dialect/Triton/IR/ascend_Traits.h
@@ -0,0 +1,6 @@
+#ifndef ASCEND_TRITON_DIALECT_TRITON_IR_TRAITS_H_
+#define ASCEND_TRITON_DIALECT_TRITON_IR_TRAITS_H_
+
+#define FLAGTREE_SPEC_Dialect_Triton_IR_TRAITS_verifyTensorSize
+
+#endif // ASCEND_TRITON_DIALECT_TRITON_IR_TRAITS_H_
diff --git a/third_party/ascend/triton_patch/lib/CMakeLists.txt b/third_party/ascend/backend/flagtree_backend_specialization/lib/CMakeLists.txt
similarity index 50%
rename from third_party/ascend/triton_patch/lib/CMakeLists.txt
rename to third_party/ascend/backend/flagtree_backend_specialization/lib/CMakeLists.txt
index 0ca0f41c5..616a173f1 100644
--- a/third_party/ascend/triton_patch/lib/CMakeLists.txt
+++ b/third_party/ascend/backend/flagtree_backend_specialization/lib/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(Dialect)
+add_subdirectory(runtime)
\ No newline at end of file
diff --git a/third_party/ascend/backend/flagtree_backend_specialization/lib/Dialect/CMakeLists.txt b/third_party/ascend/backend/flagtree_backend_specialization/lib/Dialect/CMakeLists.txt
new file mode 100644
index 000000000..8eea29534
--- /dev/null
+++ b/third_party/ascend/backend/flagtree_backend_specialization/lib/Dialect/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(Triton)
\ No newline at end of file
diff --git
a/third_party/ascend/backend/flagtree_backend_specialization/lib/Dialect/Triton/CMakeLists.txt b/third_party/ascend/backend/flagtree_backend_specialization/lib/Dialect/Triton/CMakeLists.txt new file mode 100644 index 000000000..7d59dce8e --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/lib/Dialect/Triton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) \ No newline at end of file diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/CMakeLists.txt b/third_party/ascend/backend/flagtree_backend_specialization/lib/Dialect/Triton/IR/CMakeLists.txt similarity index 68% rename from third_party/ascend/triton_patch/lib/Dialect/Triton/IR/CMakeLists.txt rename to third_party/ascend/backend/flagtree_backend_specialization/lib/Dialect/Triton/IR/CMakeLists.txt index 3b7c3746a..080083531 100644 --- a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/CMakeLists.txt +++ b/third_party/ascend/backend/flagtree_backend_specialization/lib/Dialect/Triton/IR/CMakeLists.txt @@ -1,8 +1,6 @@ -add_triton_library(Patched_TritonIR - Dialect.cpp +add_triton_library(FlagTree_ascend_TritonDialectTritonIR Ops.cpp Traits.cpp - Types.cpp DEPENDS TritonTableGen diff --git a/third_party/ascend/backend/flagtree_backend_specialization/lib/Dialect/Triton/IR/Ops.cpp b/third_party/ascend/backend/flagtree_backend_specialization/lib/Dialect/Triton/IR/Ops.cpp new file mode 100644 index 000000000..cfc6c29b0 --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/lib/Dialect/Triton/IR/Ops.cpp @@ -0,0 +1,210 @@ +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/LogicalResult.h" + +namespace mlir { +namespace triton { + +//-- SortOp -- +LogicalResult SortOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) + { + if (operands.size() != 1) { + return emitOptionalError(location, "expected exactly one operand for SortOp"); + } + + if (!isa(operands[0].getType())) { + return emitOptionalError(location, "operand must be a ranked tensor type for SortOp"); + } + + Value src = operands[0]; + auto srcTy = cast(src.getType()); + auto srcShape = srcTy.getShape(); + auto srcEnc = srcTy.getEncoding(); + + if (srcShape.empty()) { + return emitOptionalError(location, "input tensor must have rank >= 1"); + } + + Type sortedTy = RankedTensorType::get(srcShape, srcTy.getElementType(), srcEnc); + + inferredReturnTypes.push_back(sortedTy); + + return success(); +} + +//-- MakeTensorDescOp -- +void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state, + Value base, ValueRange shape, ValueRange strides, + ArrayRef blockShape, + bool isSignedInteger) +{ + auto ptrTy = dyn_cast(base.getType()); + if (!ptrTy) { + llvm::report_fatal_error("Expected pointer type"); + } + auto elemTy = ptrTy.getPointeeType(); + SmallVector blockShape64(blockShape); + auto blockTy = RankedTensorType::get(blockShape64, elemTy); + auto descTy = + TensorDescType::get(builder.getContext(), blockTy, isSignedInteger); + return 
build(builder, state, descTy, base, shape, strides); +} + +// -- DescriptorLoadOp -- +static LogicalResult verifyDescriptorLoadStoreType(Operation *op, + TensorDescType desc, + RankedTensorType tensor) +{ + RankedTensorType block = desc.getSignlessBlockType(); + ArrayRef blockShape = block.getShape(); + ArrayRef tensorShape = tensor.getShape(); + if (blockShape.size() > tensorShape.size()) { + // Allow ranked reduced load if the leading dimensions are all 1s. + for (int i = 0; i < blockShape.size() - tensorShape.size(); ++i) { + if (blockShape[i] != 1) + return op->emitOpError( + "ranked reduce load only allowed for unit dimension leading dim."); + } + blockShape = blockShape.take_back(tensorShape.size()); + } + + if (blockShape == tensorShape && + block.getElementType() == tensor.getElementType()) { + return success(); + } + return op->emitOpError("tensor descriptor block and tensor types must match"); +} + +LogicalResult DescriptorLoadOp::verify() +{ + return verifyDescriptorLoadStoreType(*this, getDesc().getType(), getType()); +} + +// -- DescriptorStoreOp -- +LogicalResult DescriptorStoreOp::verify() +{ + return verifyDescriptorLoadStoreType(*this, getDesc().getType(), + getSrc().getType()); +} + +// The following ops, including `call`, `func`, and `return` are copied and +// modified from +// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Func/IR/FuncOps.cpp +// We could revert it back once MLIR has a better inliner interface. +//-- FuncOp -- +void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, + FunctionType type, ArrayRef attrs, + ArrayRef argAttrs) { + state.addAttribute(SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); + state.attributes.append(attrs.begin(), attrs.end()); + state.addRegion(); + + if (argAttrs.empty()) + return; + assert(type.getNumInputs() == argAttrs.size()); +#if LLVM_VERSION_MAJOR < 21 + function_interface_impl::addArgAndResultAttrs( +#else // triton_v3.3.x + call_interface_impl::addArgAndResultAttrs( +#endif + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); +} + +// -- JoinOp -- +LogicalResult +JoinOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // These should have been checked by tablegen-generated code. 
+ assert(operands.size() == 2); + assert(operands[0].getType() == operands[1].getType()); + assert(isa(operands[0].getType())); + assert(isa(operands[1].getType())); + + Value lhs = operands[0]; + // Value rhs = operands[1]; + auto srcTy = cast(lhs.getType()); + + SmallVector retShape(srcTy.getShape()); + retShape.push_back(2); + + Attribute srcEnc = srcTy.getEncoding(); + Attribute retEnc; + if (srcEnc) { + if (dyn_cast(&srcEnc.getDialect()) + ->inferJoinOpEncoding(srcEnc, retEnc, location) + .failed()) { + return failure(); + } + } + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, srcTy.getElementType(), retEnc)); + return success(); +} + +// -- GatherOp -- +LogicalResult GatherOp::verify() { + RankedTensorType indicesTy = getIndices().getType(); + RankedTensorType srcTy = getSrc().getType(); + RankedTensorType resTy = getResult().getType(); + + if (indicesTy.getShape() != resTy.getShape()) { + return emitOpError("indices and output shapes must match"); + } + if (indicesTy.getEncoding() != resTy.getEncoding()) { + return emitOpError("indices and output encodings must match"); + } + if (srcTy.getElementType() != resTy.getElementType()) { + return emitOpError("input and output element types must match"); + } + if (srcTy.getRank() != indicesTy.getRank()) { + return emitOpError("input and indices ranks must match"); + } + if (getAxis() >= srcTy.getRank()) { + return emitOpError("gather dimension must be less than the input rank"); + } + for (int dim = 0; dim < indicesTy.getRank(); ++dim) { + if (dim == getAxis()) + continue; + if (indicesTy.getShape()[dim] != srcTy.getShape()[dim]) { + return emitOpError("indices dimension ") + << dim << " must match the corresponding input dimension"; + } + } + + return success(); +} + +LogicalResult GatherOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + GatherOpAdaptor adaptor(operands, attributes, properties, regions); + auto indicesType = cast(adaptor.getIndices().getType()); + auto srcType = cast(adaptor.getSrc().getType()); + + // Shape and encoding of the indices with the element type of the src. 
+ inferredReturnTypes.push_back( + RankedTensorType::get(indicesType.getShape(), srcType.getElementType(), + indicesType.getEncoding())); + return success(); +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/ascend/backend/flagtree_backend_specialization/lib/Dialect/Triton/IR/Traits.cpp b/third_party/ascend/backend/flagtree_backend_specialization/lib/Dialect/Triton/IR/Traits.cpp new file mode 100644 index 000000000..af4653fb8 --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/lib/Dialect/Triton/IR/Traits.cpp @@ -0,0 +1,46 @@ +#include "triton/Dialect/Triton/IR/Traits.h" + +#include + +#include "mlir/IR/TypeUtilities.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir; +namespace ttg = mlir::triton::gpu; + +LogicalResult OpTrait::impl::verifyTensorSize(Operation *op) { + for (auto opType : op->getOperandTypes()) { + if (auto tensorType = dyn_cast(opType)) { + int64_t numElements = 1; + for (int64_t s : tensorType.getShape()) + numElements *= s; + if (numElements > maxTensorNumElements) + return op->emitError("Maximum allowed number of elements is ") + << maxTensorNumElements << ", but " << *op + << " has more than that"; + // if ((numElements & (numElements - 1)) != 0) + // return op->emitError("Number of elements must be power-of-two, but ") + // << *op << " doesn't follow the rule (" << numElements << ")" + // << " elements"; + } + } + for (auto opType : op->getResultTypes()) { + if (auto tensorType = dyn_cast(opType)) { + int64_t numElements = 1; + for (int64_t s : tensorType.getShape()) + numElements *= s; + if (numElements > maxTensorNumElements) + return op->emitError("Maximum allowed number of elements is ") + << maxTensorNumElements << ", but " << *op + << " has more than that"; + // if ((numElements & (numElements - 1)) != 0) + // return op->emitError("Number of elements must be power-of-two, but ") + // << *op << " doesn't follow the rule (" << numElements << ")" + // << " elements"; + } + } + return success(); +} diff --git a/third_party/ascend/backend/flagtree_backend_specialization/lib/runtime/CMakeLists.txt b/third_party/ascend/backend/flagtree_backend_specialization/lib/runtime/CMakeLists.txt new file mode 100644 index 000000000..4bd5505d9 --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/lib/runtime/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(libentry) \ No newline at end of file diff --git a/third_party/ascend/backend/flagtree_backend_specialization/lib/runtime/libentry/CMakeLists.txt b/third_party/ascend/backend/flagtree_backend_specialization/lib/runtime/libentry/CMakeLists.txt new file mode 100644 index 000000000..ced1e9978 --- /dev/null +++ b/third_party/ascend/backend/flagtree_backend_specialization/lib/runtime/libentry/CMakeLists.txt @@ -0,0 +1,8 @@ +add_triton_library(FlagTree_ascend_libentry + libentry.cpp +) + +target_compile_options(FlagTree_ascend_libentry PRIVATE -fexceptions -frtti) + +find_package(Python3 REQUIRED COMPONENTS Development) +target_link_libraries(FlagTree_ascend_libentry PRIVATE Python3::Python) diff --git a/third_party/ascend/triton_patch/lib/runtime/libentry/libentry.cpp b/third_party/ascend/backend/flagtree_backend_specialization/lib/runtime/libentry/libentry.cpp similarity index 98% rename from third_party/ascend/triton_patch/lib/runtime/libentry/libentry.cpp rename to 
third_party/ascend/backend/flagtree_backend_specialization/lib/runtime/libentry/libentry.cpp index 7374fec74..94974a4c9 100644 --- a/third_party/ascend/triton_patch/lib/runtime/libentry/libentry.cpp +++ b/third_party/ascend/backend/flagtree_backend_specialization/lib/runtime/libentry/libentry.cpp @@ -1,4 +1,4 @@ -#include "runtime/libentry/libentry.h" +#include "../../../include/runtime/libentry/libentry.h" using namespace libentry; diff --git a/third_party/ascend/triton_patch/include/CMakeLists.txt b/third_party/ascend/triton_patch/include/CMakeLists.txt deleted file mode 100644 index 109c292fe..000000000 --- a/third_party/ascend/triton_patch/include/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(triton) diff --git a/third_party/ascend/triton_patch/include/triton/CMakeLists.txt b/third_party/ascend/triton_patch/include/triton/CMakeLists.txt deleted file mode 100644 index 0ca0f41c5..000000000 --- a/third_party/ascend/triton_patch/include/triton/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(Dialect) diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/CMakeLists.txt b/third_party/ascend/triton_patch/include/triton/Dialect/CMakeLists.txt deleted file mode 100644 index 5e601271e..000000000 --- a/third_party/ascend/triton_patch/include/triton/Dialect/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(Triton) diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/CMakeLists.txt b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/CMakeLists.txt deleted file mode 100644 index f33061b2d..000000000 --- a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(IR) diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/CMakeLists.txt b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/CMakeLists.txt deleted file mode 100644 index 990e3b68f..000000000 --- a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/CMakeLists.txt +++ /dev/null @@ -1,39 +0,0 @@ -set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) - -file(RELATIVE_PATH patch_rel_dir "${CMAKE_SOURCE_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}") -string(REPLACE "triton_patch" "third_party/triton" triton_rel_dir "${patch_rel_dir}") - set(triton_abs_dir "${TRITON_ROOT_DIR}/include/triton/Dialect/Triton/IR") # message(STATUS "triton_abs_dir: ${triton_abs_dir}") - -set(LLVM_TARGET_DEFINITIONS TritonOps.td) -mlir_tablegen(Ops.h.inc -gen-op-decls) -mlir_tablegen(Ops.cpp.inc -gen-op-defs) -mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) -mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) -# add_mlir_doc(TritonOps TritonOps dialects/ -gen-op-doc) - -set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonDialect.td) -mlir_tablegen(Dialect.h.inc -gen-dialect-decls) -mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) -# add_mlir_doc(TritonDialect TritonDialect dialects/ -gen-dialect-doc) - -# TODO: When upgrading to Triton 3.4.0, enable the commented line below and remove the current line. 
-# set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonTypes.td) -set(LLVM_TARGET_DEFINITIONS TritonTypes.td) -mlir_tablegen(Types.h.inc -gen-typedef-decls) -mlir_tablegen(Types.cpp.inc -gen-typedef-defs) - -set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonInterfaces.td) -mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls) -mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs) - -set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonTypeInterfaces.td) -mlir_tablegen(TritonTypeInterfaces.h.inc -gen-type-interface-decls) -mlir_tablegen(TritonTypeInterfaces.cpp.inc -gen-type-interface-defs) - -# TODO: When upgrading to Triton 3.4.0, enable the commented line below and remove the current line. -# set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonOpInterfaces.td) -set(LLVM_TARGET_DEFINITIONS TritonOpInterfaces.td) -mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls) -mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs) - -add_public_tablegen_target(Patched_TritonTableGen) diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/CMakeLists.txt.bak b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/CMakeLists.txt.bak deleted file mode 100644 index 9b004a8bd..000000000 --- a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/CMakeLists.txt.bak +++ /dev/null @@ -1,40 +0,0 @@ -set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) - -file(RELATIVE_PATH patch_rel_dir "${CMAKE_SOURCE_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}") -string(REPLACE "triton_patch" "third_party/triton" triton_rel_dir "${patch_rel_dir}") -set(triton_abs_dir "${CMAKE_SOURCE_DIR}/${triton_rel_dir}") -# message(STATUS "triton_abs_dir: ${triton_abs_dir}") - -set(LLVM_TARGET_DEFINITIONS TritonOps.td) -mlir_tablegen(Ops.h.inc -gen-op-decls) -mlir_tablegen(Ops.cpp.inc -gen-op-defs) -mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) -mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) -# add_mlir_doc(TritonOps TritonOps dialects/ -gen-op-doc) - -set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonDialect.td) -mlir_tablegen(Dialect.h.inc -gen-dialect-decls) -mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) -# add_mlir_doc(TritonDialect TritonDialect dialects/ -gen-dialect-doc) - -# TODO: When upgrading to Triton 3.4.0, enable the commented line below and remove the current line. -# set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonTypes.td) -set(LLVM_TARGET_DEFINITIONS TritonTypes.td) -mlir_tablegen(Types.h.inc -gen-typedef-decls) -mlir_tablegen(Types.cpp.inc -gen-typedef-defs) - -set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonInterfaces.td) -mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls) -mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs) - -set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonTypeInterfaces.td) -mlir_tablegen(TritonTypeInterfaces.h.inc -gen-type-interface-decls) -mlir_tablegen(TritonTypeInterfaces.cpp.inc -gen-type-interface-defs) - -# TODO: When upgrading to Triton 3.4.0, enable the commented line below and remove the current line. 
-# set(LLVM_TARGET_DEFINITIONS ${triton_abs_dir}/TritonOpInterfaces.td) -set(LLVM_TARGET_DEFINITIONS TritonOpInterfaces.td) -mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls) -mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs) - -add_public_tablegen_target(Patched_TritonTableGen) diff --git a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/Dialect.h b/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/Dialect.h deleted file mode 100644 index c0b0885ed..000000000 --- a/third_party/ascend/triton_patch/include/triton/Dialect/Triton/IR/Dialect.h +++ /dev/null @@ -1,75 +0,0 @@ -// TODO: When upgrading to Triton 3.4.0, remove this file and use the upstream Triton file. -#ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_ -#define TRITON_DIALECT_TRITON_IR_DIALECT_H_ - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BuiltinOps.h" - -#include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Interfaces/FunctionInterfaces.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "triton/Dialect/Triton/IR/Dialect.h.inc" -#include "triton/Dialect/Triton/IR/OpInterfaces.h" -#include "triton/Dialect/Triton/IR/OpsEnums.h.inc" -#include "triton/Dialect/Triton/IR/Traits.h" -#include "triton/Dialect/Triton/IR/Types.h" -#include "mlir/IR/Dialect.h" - -#define GET_OP_CLASSES -#include "triton/Dialect/Triton/IR/Ops.h.inc" - -namespace mlir { -namespace triton { - -struct GlobalMemory : public SideEffects::Resource::Base { - StringRef getName() final { return ""; } -}; - -class DialectInferLayoutInterface : public DialectInterface::Base { -public: - DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {} - - virtual LogicalResult inferTransOpEncoding(Attribute operandEncoding, ArrayRef order, - Attribute &resultEncoding) const = 0; - - virtual LogicalResult inferReduceOpEncoding(Attribute operandEncoding, unsigned axis, - Attribute &resultEncoding) const = 0; - - virtual LogicalResult inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis, Attribute &resultEncoding, - std::optional location) const = 0; - - // Note: This function only verifies the operand encoding. It doesn't infer - // the result encoding. - virtual LogicalResult inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx, Attribute retEncoding, - std::optional location) const = 0; - - // Tries to compute the encoding for the result of a reshape operation that - // makes the reshape a "nop", i.e. the same GPU threads contain the same - // elements as before the reshape. Note that this is not always possible (in - // which case you'd need to choose a different layout for the input to the - // reshape). 
- virtual LogicalResult inferReshapeOpNoReorderEncoding(ArrayRef srcShape, Attribute srcEnc, - ArrayRef dstShape, Attribute &dstEnc, - std::optional loc) const = 0; - - virtual LogicalResult inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, - std::optional loc) const = 0; - - virtual LogicalResult inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc, - std::optional loc) const = 0; - - // Verify that the encoding are compatible to be used together in a dot - // operation - virtual LogicalResult verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA, - Attribute operandEncodingB) const = 0; -}; - -} // namespace triton -} // namespace mlir - -#endif // TRITON_IR_DIALECT_H_ diff --git a/third_party/ascend/triton_patch/lib/Dialect/CMakeLists.txt b/third_party/ascend/triton_patch/lib/Dialect/CMakeLists.txt deleted file mode 100644 index 5e601271e..000000000 --- a/third_party/ascend/triton_patch/lib/Dialect/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(Triton) diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/CMakeLists.txt b/third_party/ascend/triton_patch/lib/Dialect/Triton/CMakeLists.txt deleted file mode 100644 index f33061b2d..000000000 --- a/third_party/ascend/triton_patch/lib/Dialect/Triton/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(IR) diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Dialect.cpp b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Dialect.cpp deleted file mode 100644 index 1d9c86f4d..000000000 --- a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Dialect.cpp +++ /dev/null @@ -1,140 +0,0 @@ -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/Triton/IR/Types.h" - -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/UB/IR/UBOps.h" -#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" -#include "llvm/ADT/StringSwitch.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/raw_ostream.h" - -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/IR/DialectImplementation.h" - -#include "mlir/Transforms/InliningUtils.h" -#include "triton/Dialect/Triton/IR/Dialect.cpp.inc" -#include "triton/Dialect/Triton/IR/TritonTypeInterfaces.cpp.inc" -#include "triton/Dialect/Triton/IR/OpInterfaces.cpp.inc" - -using namespace mlir; -using namespace mlir::triton; - -//===----------------------------------------------------------------------===// -// TritonDialect Dialect Interfaces -//===----------------------------------------------------------------------===// - -namespace { -struct TritonInlinerInterface : public DialectInlinerInterface { - using DialectInlinerInterface::DialectInlinerInterface; - - bool isLegalToInline(Operation *call, Operation *callable, - bool wouldBeCloned) const final { - auto funcOp = dyn_cast(callable); - if (!funcOp) - return true; - if (funcOp->hasAttr("noinline")) - return !funcOp->getAttrOfType("noinline").getValue(); - return true; - } - - bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, - IRMapping &valueMapping) const final { - return true; - } - - bool isLegalToInline(Operation *, Region *, bool wouldBeCloned, - IRMapping &) const final { - return true; - } - //===--------------------------------------------------------------------===// - // Transformation Hooks - //===--------------------------------------------------------------------===// - - /// Handle the given inlined terminator by replacing it with a new operation - /// as necessary. 
- void handleTerminator(Operation *op, Block *newDest) const final { - // Only return needs to be handled here. - auto returnOp = dyn_cast(op); - if (!returnOp) - return; - - // Replace the return with a branch to the dest. - OpBuilder builder(op); - builder.create(op->getLoc(), newDest, - returnOp.getOperands()); - op->erase(); - } - - /// Handle the given inlined terminator by replacing it with a new operation - /// as necessary. - void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { - // Only return needs to be handled here. - auto returnOp = cast(op); - - // Replace the values directly with the return operands. - assert(returnOp.getNumOperands() == valuesToRepl.size()); - for (const auto &it : llvm::enumerate(returnOp.getOperands())) - valuesToRepl[it.index()].replaceAllUsesWith(it.value()); - } -}; - -struct TensorModel - : public TensorOrMemDesc::ExternalModel { - Type getElementType(Type pointer) const { - return cast(pointer).getElementType(); - } - Attribute getEncoding(Type pointer) const { - return cast(pointer).getEncoding(); - } - ArrayRef getShape(Type pointer) const { - return cast(pointer).getShape(); - } - int64_t getRank(Type pointer) const { - return cast(pointer).getRank(); - } - int64_t getElementTypeBitWidth(Type pointer) const { - return cast(pointer).getElementTypeBitWidth(); - } -}; - -struct MemDescModel - : public TensorOrMemDesc::ExternalModel { - Type getElementType(Type pointer) const { - return cast(pointer).getElementType(); - } - Attribute getEncoding(Type pointer) const { - return cast(pointer).getEncoding(); - } - ArrayRef getShape(Type pointer) const { - return cast(pointer).getShape(); - } - int64_t getRank(Type pointer) const { - return cast(pointer).getShape().size(); - } - int64_t getElementTypeBitWidth(Type pointer) const { - return cast(pointer).getElementType().getIntOrFloatBitWidth(); - } -}; - -} // namespace - -void TritonDialect::initialize() { - registerTypes(); - - addOperations< -#define GET_OP_LIST -#include "triton/Dialect/Triton/IR/Ops.cpp.inc" - >(); - - // We can also add interface here. 
- addInterfaces(); - - RankedTensorType::attachInterface(*getContext()); - MemDescType::attachInterface(*getContext()); -} - -Operation *TritonDialect::materializeConstant(OpBuilder &builder, - Attribute value, Type type, - Location loc) { - return arith::ConstantOp::materialize(builder, value, type, loc); -} diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Ops.cpp b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Ops.cpp deleted file mode 100644 index 2ab26afa9..000000000 --- a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Ops.cpp +++ /dev/null @@ -1,1183 +0,0 @@ -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/Interfaces/FunctionImplementation.h" -#include "mlir/Interfaces/FunctionInterfaces.h" -#include "mlir/Support/LLVM.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/Triton/IR/Types.h" -#include "triton/Dialect/Triton/IR/Utility.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/LogicalResult.h" - -namespace mlir { -namespace triton { - -void LoadOp::getEffects( - SmallVectorImpl> - &effects) { - effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable(), - triton::GlobalMemory::get()); - if (getIsVolatile()) - effects.emplace_back(MemoryEffects::Write::get(), - SideEffects::DefaultResource::get()); -} - -} // namespace triton -} // namespace mlir - -#define GET_OP_CLASSES -#include "triton/Dialect/Triton/IR/Ops.cpp.inc" - -// enum attribute definitions -#include "triton/Dialect/Triton/IR/OpsEnums.cpp.inc" - -namespace mlir { -namespace triton { - -//-- LoadOp -- -void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, - CacheModifier cache, EvictionPolicy evict, bool isVolatile) { - LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, - /*boundaryCheck=*/ArrayRef{}, /*padding=*/std::nullopt, - cache, evict, isVolatile); -} - -void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, - ArrayRef boundaryCheck, - std::optional padding, CacheModifier cache, - EvictionPolicy evict, bool isVolatile) { - LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, boundaryCheck, - padding, cache, evict, isVolatile); -} - -void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, - Value mask, CacheModifier cache, EvictionPolicy evict, - bool isVolatile) { - LoadOp::build(builder, state, ptr, mask, /*other=*/{}, - /*boundaryCheck=*/ArrayRef{}, - /*padding=*/std::nullopt, cache, evict, isVolatile); -} - -void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, - Value mask, Value other, CacheModifier cache, - EvictionPolicy evict, bool isVolatile) { - LoadOp::build(builder, state, ptr, mask, other, - /*boundaryCheck=*/ArrayRef{}, - /*padding=*/std::nullopt, cache, evict, isVolatile); -} - -void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, - Value mask, Value other, ArrayRef boundaryCheck, - std::optional padding, CacheModifier cache, - EvictionPolicy evict, bool isVolatile) { - auto paddingAttr = - padding.has_value() - ? PaddingOptionAttr::get(builder.getContext(), padding.value()) - : PaddingOptionAttr(); - LoadOp::build(builder, state, ptr, mask, other, - builder.getDenseI32ArrayAttr(boundaryCheck), paddingAttr, cache, - evict, isVolatile); -} - -// load(ptr, splat(1), ...) -> load(ptr, ...) -// load(ptr, splat(0), other, ...) 
-> other -struct CanonicalizeMaskedLoadPattern : public OpRewritePattern { - CanonicalizeMaskedLoadPattern(MLIRContext *context) - : OpRewritePattern(context, 1) {} - - LogicalResult matchAndRewrite(LoadOp loadOp, - PatternRewriter &rewriter) const override { - auto mask = loadOp.getMask(); - if (!mask) - return failure(); - - auto constantMask = mask.getDefiningOp(); - if (!constantMask) - return failure(); - - auto splatMask = mlir::dyn_cast(constantMask.getValue()); - if (!splatMask) - return failure(); - - if (splatMask.getSplatValue().getValue() == true) { - // mask = splat(1) - rewriter.replaceOpWithNewOp( - loadOp, loadOp.getType(), loadOp.getPtr(), Value(), Value(), - loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), - loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); - } else { - // mask = splat(0) - - // If there's no "other", the value is "undef". Perhaps we want to - // optimize it in the future.x - auto otherVal = loadOp.getOther(); - if (!otherVal) - return failure(); - rewriter.replaceOp(loadOp, otherVal); - } - return success(); - } -}; - -void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - -//-- StoreOp -- -void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, - Value value, CacheModifier cache, EvictionPolicy evict) { - return StoreOp::build(builder, state, ptr, value, /*mask=*/{}, - /*boundaryCheck=*/{}, cache, evict); -} - -void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, - Value value, Value mask, CacheModifier cache, - EvictionPolicy evict) { - return StoreOp::build(builder, state, ptr, value, mask, /*boundaryCheck=*/{}, - cache, evict); -} - -void StoreOp::build(OpBuilder &builder, OperationState &state, Value ptr, - Value value, ArrayRef boundaryCheck, - CacheModifier cache, EvictionPolicy evict) { - return StoreOp::build(builder, state, ptr, value, /*mask=*/{}, - builder.getDenseI32ArrayAttr(boundaryCheck), cache, - evict); -} - -// store(ptr, value, splat(1), ...) -> store(ptr, value, ...) -// store(ptr, value, splat(0), ...) 
-> [none] -struct CanonicalizeMaskedStorePattern : public OpRewritePattern { - CanonicalizeMaskedStorePattern(MLIRContext *context) - : OpRewritePattern(context, 1) {} - - LogicalResult matchAndRewrite(StoreOp storeOp, - PatternRewriter &rewriter) const override { - auto mask = storeOp.getMask(); - if (!mask) - return failure(); - - auto constantMask = mask.getDefiningOp(); - if (!constantMask) - return failure(); - - auto splatMask = mlir::dyn_cast(constantMask.getValue()); - if (!splatMask) - return failure(); - - if (splatMask.getSplatValue().getValue() == true) { - // mask = splat(1) - rewriter.replaceOpWithNewOp( - storeOp, storeOp.getPtr(), storeOp.getValue(), storeOp.getCache(), - storeOp.getEvict()); - } else { - // mask = splat(0) - rewriter.eraseOp(storeOp); - } - return success(); - } -}; - -void StoreOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - -//-- TransOp -- -OpFoldResult TransOp::fold(FoldAdaptor adaptor) { - // transpose(x, order=[0, 1, ...]) -> x - if (isIota(getOrder())) { - return getSrc(); - } - - // transpose(transpose(x)) -> transpose(x) - if (auto innerTrans = getSrc().getDefiningOp()) { - setOrder(applyPermutation(innerTrans.getOrder(), getOrder())); - setOperand(innerTrans.getSrc()); - return getResult(); - } - - return {}; -} - -LogicalResult TransOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - // type is the same as the input - auto argTy = cast(operands[0].getType()); - auto order = properties.as()->order.asArrayRef(); - SmallVector retShape = applyPermutation(argTy.getShape(), order); - - auto retEltTy = argTy.getElementType(); - Attribute argEncoding = argTy.getEncoding(); - Attribute retEncoding; - if (argEncoding) { - Dialect &dialect = argEncoding.getDialect(); - auto inferLayoutInterface = dyn_cast(&dialect); - if (inferLayoutInterface - ->inferTransOpEncoding(argEncoding, order, retEncoding) - .failed()) { - return failure(); - } - } - if (auto memDescTy = dyn_cast(argTy)) { - inferredReturnTypes.push_back(MemDescType::get( - retShape, retEltTy, retEncoding, memDescTy.getMemorySpace(), - memDescTy.getMutableMemory())); - } else { - inferredReturnTypes.push_back( - RankedTensorType::get(retShape, retEltTy, retEncoding)); - } - return success(); -} - -//-- SortOp -- -LogicalResult SortOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) - { - if (operands.size() != 1) { - return emitOptionalError(location, "expected exactly one operand for SortOp"); - } - - if (!isa(operands[0].getType())) { - return emitOptionalError(location, "operand must be a ranked tensor type for SortOp"); - } - - Value src = operands[0]; - auto srcTy = cast(src.getType()); - auto srcShape = srcTy.getShape(); - auto srcEnc = srcTy.getEncoding(); - - if (srcShape.empty()) { - return emitOptionalError(location, "input tensor must have rank >= 1"); - } - - Type sortedTy = RankedTensorType::get(srcShape, srcTy.getElementType(), srcEnc); - - inferredReturnTypes.push_back(sortedTy); - - return success(); -} - -LogicalResult TransOp::verify() { - // Check that the op's `order` attribute is a permutation of the right length. 
- auto srcTy = getSrc().getType(); - - ArrayRef order = getOrder(); - if (order.size() != srcTy.getRank()) { - return emitError("order must have the same size as the rank of the " - "operand and result"); - } - - SmallVector sortedOrder(order); - llvm::sort(sortedOrder); - for (int32_t i = 0; i < sortedOrder.size(); i++) { - if (sortedOrder[i] != i) { - return emitError("order must be a permutation of [0, ..., rank - 1]"); - } - } - - return success(); -} - -//-- DotOp -- -LogicalResult -DotOp::inferReturnTypes(MLIRContext *context, std::optional location, - ValueRange operands, DictionaryAttr attributes, - OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - // type is the same as the accumulator - auto accTy = cast(operands[2].getType()); - inferredReturnTypes.push_back(accTy); - - // verify encodings - auto aEnc = cast(operands[0].getType()).getEncoding(); - auto bEnc = cast(operands[1].getType()).getEncoding(); - auto retEnc = accTy.getEncoding(); - if (aEnc) { - assert(bEnc && retEnc); - Dialect &dialect = retEnc.getDialect(); - auto interface = dyn_cast(&dialect); - if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed()) - return failure(); - if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed()) - return failure(); - } - return success(); -} - -LogicalResult DotOp::verify() { - auto aTy = getA().getType(); - auto bTy = getB().getType(); - if (aTy.getElementType().getIntOrFloatBitWidth() != - bTy.getElementType().getIntOrFloatBitWidth()) - return emitError( - "element types of operands A and B must have same bit width"); - auto aEncoding = aTy.getEncoding(); - auto bEncoding = bTy.getEncoding(); - if (!aEncoding && !bEncoding) - return success(); - // Verify that the encodings are valid. 
- if (!aEncoding || !bEncoding) - return emitError("mismatching encoding between A and B operands"); - auto accTy = getC().getType(); - auto retEnc = accTy.getEncoding(); - if (!retEnc) - return emitError("miss encoding of C operand"); - Dialect &dialect = retEnc.getDialect(); - auto interface = cast(&dialect); - return interface->verifyDotOpEncodingCompatibility(getOperation(), aEncoding, - bEncoding); -} - -//-- MakeRangeOp -- -OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) { - // make_range(start, start + 1) -> constant(start) - if (adaptor.getStart() + 1 == adaptor.getEnd()) { - auto shapedType = cast(getType()); - return SplatElementsAttr::get(shapedType, adaptor.getStartAttr()); - } - return {}; -} - -LogicalResult MakeRangeOp::verify() { - int64_t start = getStartAttr().getInt(); - int64_t end = getEndAttr().getInt(); - if (start > end) { - return this->emitOpError() << "start must be less than or equal to end"; - } - auto ty = getType(); - if (ty.getShape().size() != 1) { - return this->emitOpError() << "return type must be a 1D tensor"; - } - if (end - start != ty.getShape()[0]) { - return this->emitOpError() - << "number of elements in returned tensor, " << ty.getShape()[0] - << ", must match size of range [" << start << ", " << end - << "), which has " << end - start << " elements"; - } - if (!ty.getElementType().isInteger(32)) { - return this->emitOpError() << "returned tensor must have i32 elements"; - } - return success(); -} - -//-- ReduceOp -- -static LogicalResult -inferReduceReturnShape(RankedTensorType argTy, Type retEltTy, int axis, - SmallVectorImpl &inferredReturnTypes) { - auto retShape = argTy.getShape().vec(); - retShape.erase(retShape.begin() + axis); - if (retShape.empty()) { - // 0d-tensor -> scalar - inferredReturnTypes.push_back(retEltTy); - } else { - // nd-tensor where n >= 1 - // infer encoding - Attribute argEncoding = argTy.getEncoding(); - Attribute retEncoding; - if (argEncoding) { - Dialect &dialect = argEncoding.getDialect(); - auto inferLayoutInterface = - dyn_cast(&dialect); - if (inferLayoutInterface - ->inferReduceOpEncoding(argEncoding, axis, retEncoding) - .failed()) { - llvm::report_fatal_error("failed to infer layout for ReduceOp"); - return failure(); - } - } - // create type - inferredReturnTypes.push_back( - RankedTensorType::get(retShape, retEltTy, retEncoding)); - } - return success(); -} - -void ReduceOp::build(OpBuilder &builder, OperationState &state, - ValueRange operands, int axis) { - SmallVector inferredReturnTypes; - for (unsigned i = 0; i < operands.size(); ++i) { - auto argTy = cast(operands[i].getType()); - auto retEltTy = argTy.getElementType(); - (void)inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes); - } - - ReduceOp::build(builder, state, inferredReturnTypes, operands, axis); -} - -LogicalResult ReduceOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - Properties *prop = properties.as(); - int axis = prop->axis.getInt(); - for (auto arg : operands) { - auto argTy = cast(arg.getType()); - auto retEltTy = argTy.getElementType(); - if (inferReduceReturnShape(argTy, retEltTy, axis, inferredReturnTypes) - .failed()) { - return failure(); - } - } - return success(); -} - -// Helpers for Reductions and Scans -template LogicalResult verifyReduceScan(Op &op) { - if (op.getOperands().empty()) { - return op.emitOpError() << "must have at least 1 
operand";
-  }
-  if (op.getNumOperands() != op.getNumResults()) {
-    return op.emitOpError() << "must have the same number of inputs as outputs";
-  }
-
-  auto getElementType = [](Type ty) {
-    if (auto tensorType = dyn_cast<RankedTensorType>(ty)) {
-      return tensorType.getElementType();
-    }
-    return ty;
-  };
-
-  for (auto [opElemTy, resTy] :
-       llvm::zip(op.getElementTypes(), op.getResultTypes())) {
-    if (opElemTy != getElementType(resTy)) {
-      return op.emitOpError() << "operand types and result types must agree";
-    }
-  }
-  return success();
-}
-
-template <class ReturnOp, class Op>
-static LogicalResult verifyRegionsImpl(Op &op) {
-  auto argElementTypes = op.getElementTypes();
-  const auto &operands = op.getOperands();
-  const auto numArgs = 2 * operands.size();
-  auto &block = *op.getBody();
-  if (block.getNumArguments() != numArgs) {
-    return op.emitOpError() << "nested block must take " << numArgs
-                            << " arguments, but given block with "
-                            << block.getNumArguments() << " arguments";
-  }
-  unsigned i = 0;
-  const auto &blockArgTypes = block.getArgumentTypes();
-  for (unsigned i = 0; i < numArgs; ++i) {
-    const auto &blockArgTy = blockArgTypes[i];
-    const auto &argElemTy = argElementTypes[i % operands.size()];
-    if (blockArgTy != argElemTy) {
-      return op.emitOpError()
-             << "type mismatch on combine operation. Expected argument " << i
-             << " to have type " << argElemTy << " but got " << blockArgTy;
-    }
-  }
-
-  auto terminator = dyn_cast<ReturnOp>(block.getTerminator());
-  if (!terminator) {
-    return op.emitOpError()
-           << "combine operation must be terminated "
-           << "with a ReduceReturnOp but got " << block.getTerminator();
-  }
-  const auto &combineResults = terminator->getOperands();
-  if (combineResults.size() != operands.size()) {
-    return op.emitOpError()
-           << "expected combine operation to return " << operands.size()
-           << " values but got " << combineResults.size();
-  }
-  for (unsigned i = 0; i < combineResults.size(); ++i) {
-    const auto &resultTy = combineResults[i].getType();
-    const auto &argElemTy = argElementTypes[i];
-    if (resultTy != argElemTy) {
-      return op.emitOpError()
-             << "type mismatch on combine operation. 
Expected argument " << i - << " to have type " << argElemTy << " but got " << resultTy; - } - } - return success(); -} - -static llvm::SmallVector -getInputTypesImpl(const Operation::operand_range &operands) { - llvm::SmallVector srcTys; - srcTys.reserve(operands.size()); - for (const auto &ty : operands.getTypes()) { - srcTys.push_back(cast(ty)); - } - return srcTys; -} - -static llvm::SmallVector -getElementTypesImpl(const Operation::operand_range &operands) { - llvm::SmallVector srcElemTys; - srcElemTys.reserve(operands.size()); - for (const auto &op : operands) { - srcElemTys.push_back(cast(op.getType()).getElementType()); - } - return srcElemTys; -} - -LogicalResult ReduceOp::verify() { return verifyReduceScan(*this); } - -LogicalResult ReduceOp::verifyRegions() { - return verifyRegionsImpl(*this); -} - -llvm::SmallVector ReduceOp::getInputTypes() { - return getInputTypesImpl(this->getOperands()); -} - -llvm::SmallVector ReduceOp::getElementTypes() { - return getElementTypesImpl(this->getOperands()); -} - -unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); } - -//-- ScanOp -- -void ScanOp::build(OpBuilder &builder, OperationState &state, - ValueRange operands, int axis, bool reverse) { - SmallVector inferredReturnTypes; - state.addAttribute("reverse", builder.getBoolAttr(reverse)); - for (auto arg : operands) - inferredReturnTypes.push_back(arg.getType()); - ReduceOp::build(builder, state, inferredReturnTypes, operands, axis); -} - -LogicalResult -ScanOp::inferReturnTypes(MLIRContext *context, std::optional location, - ValueRange operands, DictionaryAttr attributes, - OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - for (auto arg : operands) - inferredReturnTypes.push_back(arg.getType()); - return success(); -} - -LogicalResult ScanOp::verify() { return verifyReduceScan(*this); } - -LogicalResult ScanOp::verifyRegions() { - return verifyRegionsImpl(*this); -} - -llvm::SmallVector ScanOp::getInputTypes() { - return getInputTypesImpl(this->getOperands()); -} - -llvm::SmallVector ScanOp::getElementTypes() { - return getElementTypesImpl(this->getOperands()); -} - -unsigned ScanOp::getNumOperands() { return this->getOperands().size(); } - -//-- SplatOp -- -OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { - auto value = adaptor.getSrc(); - if (!value) - return {}; - if (!isa(value)) - return {}; - auto shapedType = cast(getType()); - auto ret = SplatElementsAttr::get(shapedType, ArrayRef(value)); - return ret; -} - -//-- ExpandDimsOp -- -LogicalResult ExpandDimsOp::inferReturnTypes( - MLIRContext *context, std::optional loc, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - // infer shape - auto arg = operands[0]; - auto argTy = cast(arg.getType()); - auto retShape = argTy.getShape().vec(); - Properties *prop = properties.as(); - int axis = prop->axis.getInt(); - retShape.insert(retShape.begin() + axis, 1); - // infer encoding - Attribute argEncoding = argTy.getEncoding(); - Attribute retEncoding; - if (argEncoding) { - Dialect &dialect = argEncoding.getDialect(); - auto inferLayoutInterface = dyn_cast(&dialect); - if (inferLayoutInterface - ->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding, loc) - .failed()) - return emitOptionalError(loc, "failed to infer layout for ExpandDimsOp"); - } - // create type - auto argEltTy = argTy.getElementType(); - inferredReturnTypes.push_back( - RankedTensorType::get(retShape, argEltTy, 
retEncoding)); - return success(); -} - -LogicalResult ExpandDimsOp::canonicalize(ExpandDimsOp op, - PatternRewriter &rewriter) { - auto definingOp = op.getSrc().getDefiningOp(); - if (!definingOp) { - return failure(); - } - // expand_dims(splat) -> splat - if (auto splat = dyn_cast(definingOp)) { - rewriter.replaceOpWithNewOp(op, op.getType(), splat.getSrc()); - return success(); - } - // expand_dims(broadcast(x)) -> broadcast(expand_dims(x)) - // - // On its own this doesn't do much, but consider - // broadcast(expand_dims(broadcast)) - // -> broadcast(broadcast(expand_dims)) - // -> broadcast(expand_dims) - if (auto broadcast = dyn_cast(definingOp)) { - auto src = broadcast.getSrc(); - auto srcTy = src.getType(); - SmallVector newExpandShape(srcTy.getShape()); - newExpandShape.insert(newExpandShape.begin() + op.getAxis(), 1); - - // Infer the encoding of the new expand op, if encodings are present. - Attribute newExpandEnc; - if (auto srcEnc = srcTy.getEncoding()) { - if (dyn_cast(&srcEnc.getDialect()) - ->inferExpandDimsOpEncoding(srcEnc, op.getAxis(), newExpandEnc, - op.getLoc()) - .failed()) { - return emitOptionalError(op.getLoc(), - "failed to infer layout for ExpandDimsOp"); - } - } - - auto newExpandTy = RankedTensorType::get( - newExpandShape, srcTy.getElementType(), newExpandEnc); - auto newExpand = rewriter.create(op.getLoc(), newExpandTy, - src, op.getAxis()); - auto newBroadcast = rewriter.create( - broadcast.getLoc(), op.getType(), newExpand.getResult()); - rewriter.replaceOp(op, {newBroadcast.getResult()}); - return success(); - } - - return failure(); -} - -template -static OpFoldResult foldViewLikeOp(ViewLikeOp op, Attribute value) { - if (!value) - return {}; - - auto shapedType = cast(op.getType()); - if (auto denseElemsAttr = dyn_cast(value)) { - if (denseElemsAttr.isSplat()) { - return denseElemsAttr.resizeSplat(shapedType); - } else { - return denseElemsAttr.reshape(shapedType); - } - } - return {}; -} - -OpFoldResult ExpandDimsOp::fold(FoldAdaptor adaptor) { - return foldViewLikeOp(*this, adaptor.getSrc()); -} - -//-- ReshapeOp -- -template -LogicalResult canonicalizeViewOrBroadcast(OpType op, - PatternRewriter &rewriter) { - auto definingOp = op.getSrc().getDefiningOp(); - if (!definingOp) { - return failure(); - } - - // view(view) -> view - if (auto parentView = dyn_cast(definingOp)) { - rewriter.replaceOpWithNewOp(op, TypeRange({op.getType()}), - parentView->getOperands(), - parentView->getAttrs()); - return success(); - } - - // view(splat) -> splat - if (auto splat = dyn_cast(definingOp)) { - rewriter.replaceOpWithNewOp(op, op.getType(), splat.getSrc()); - return success(); - } - - return failure(); -} - -LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) { - if (!op.getAllowReorder() || op.getEfficientLayout()) - return failure(); - return canonicalizeViewOrBroadcast(op, rewriter); -} - -OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { - if (getType() == getSrc().getType()) { - // no-op - return getSrc(); - } - - return foldViewLikeOp(*this, adaptor.getSrc()); -} - -LogicalResult ReshapeOp::verify() { - auto dstTy = getType(); - auto srcTy = getSrc().getType(); - if (getType().getNumElements() != srcTy.getNumElements()) { - return emitError( - "number of src and dst elements of reshape must be the same"); - } - - Attribute srcEnc = srcTy.getEncoding(); - Attribute dstEnc = dstTy.getEncoding(); - if (!!srcEnc != !!dstEnc) { - return emitError("Op requires that either (a) src and dst both have " - "encodings, or (b) neither 
does."); - } - - if (srcEnc && !getAllowReorder()) { - Attribute inferredDstEnc; - if (cast(&srcEnc.getDialect()) - ->inferReshapeOpNoReorderEncoding(srcTy.getShape(), srcEnc, - dstTy.getShape(), inferredDstEnc, - getLoc()) - .failed()) { - return emitError("This reshape is impossible without reordering, but " - "reordering is not allowed. Try choosing a different " - "encoding for the input tensor (or allow reordering)."); - } - if (inferredDstEnc != dstEnc) { - return emitError("Expected result encoding ") - << inferredDstEnc << " but was " << dstEnc; - } - } - - return success(); -} - -//-- FpToFpOp -- -LogicalResult FpToFpOp::verify() { - auto dstType = getType().getElementType(); - auto srcType = getSrc().getType().getElementType(); - if ((dstType.getIntOrFloatBitWidth() < srcType.getIntOrFloatBitWidth()) && - (!getRounding().has_value())) { - return emitError("Rounding mode is required for FP downcast"); - } - return success(); -} - -//-- BroadcastOp -- -LogicalResult BroadcastOp::canonicalize(BroadcastOp op, - PatternRewriter &rewriter) { - return canonicalizeViewOrBroadcast(op, rewriter); -} - -OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { - if (getType() == getSrc().getType()) { - // no-op - return getSrc(); - } - - auto value = adaptor.getSrc(); - if (!value) - return {}; - - if (auto denseElemsAttr = dyn_cast(value)) { - auto shapedType = cast(getType()); - return denseElemsAttr.resizeSplat(shapedType); - } - return {}; -} - -LogicalResult BroadcastOp::verify() { - auto src = getSrc(); - auto srcTensorType = cast(src.getType()); - auto srcShape = srcTensorType.getShape(); - auto result = getResult(); - auto resultTensorType = cast(result.getType()); - auto resultShape = resultTensorType.getShape(); - if (srcShape.size() != resultShape.size()) { - return emitError("rank of source must be same as rank of result"); - } - for (int i = 0; i < srcShape.size(); i++) { - if (srcShape[i] != 1 && srcShape[i] != resultShape[i]) { - return emitError("Different dimensions at index ") - << i << " between source and result. 
" - << "Broadcast requires the source dimension to be 1."; - } - } - return success(); -} - -//-- MakeTensorPtrOp -- -void MakeTensorPtrOp::build(OpBuilder &builder, OperationState &state, - Value base, ValueRange shape, ValueRange strides, - ValueRange offsets, ArrayRef tensorShape, - ArrayRef order) -{ - // Get pointer type from `base` - auto pointerType = cast(base.getType()); - assert(pointerType != nullptr); - - // Build type `tt.ptr>` - auto tensorType = RankedTensorType::get( - SmallVector(tensorShape.begin(), tensorShape.end()), - pointerType.getPointeeType()); - auto result = PointerType::get(tensorType, 1); - - return build(builder, state, result, base, shape, strides, offsets, - builder.getDenseI32ArrayAttr(order)); -} - -//-- AdvanceOp -- -OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) { - // advance(ptr, 0, 0) -> ptr - SmallVector rawOffsets = getOffsets(); - auto offsets = getConstantIntValues(rawOffsets); - if (!offsets.has_value()) - return {}; - for (int64_t offset : offsets.value()) - if (offset != 0) - return {}; - return getPtr(); -} - -//-- MakeTensorDescOp -- -void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state, - Value base, ValueRange shape, ValueRange strides, - ArrayRef blockShape, - bool isSignedInteger) -{ - auto ptrTy = dyn_cast(base.getType()); - if (!ptrTy) { - llvm::report_fatal_error("Expected pointer type"); - } - auto elemTy = ptrTy.getPointeeType(); - SmallVector blockShape64(blockShape); - auto blockTy = RankedTensorType::get(blockShape64, elemTy); - auto descTy = - TensorDescType::get(builder.getContext(), blockTy, isSignedInteger); - return build(builder, state, descTy, base, shape, strides); -} - -// -- DescriptorLoadOp -- -static LogicalResult verifyDescriptorLoadStoreType(Operation *op, - TensorDescType desc, - RankedTensorType tensor) -{ - RankedTensorType block = desc.getSignlessBlockType(); - ArrayRef blockShape = block.getShape(); - ArrayRef tensorShape = tensor.getShape(); - if (blockShape.size() > tensorShape.size()) { - // Allow ranked reduced load if the leading dimensions are all 1s. - for (int i = 0; i < blockShape.size() - tensorShape.size(); ++i) { - if (blockShape[i] != 1) - return op->emitOpError( - "ranked reduce load only allowed for unit dimension leading dim."); - } - blockShape = blockShape.take_back(tensorShape.size()); - } - - if (blockShape == tensorShape && - block.getElementType() == tensor.getElementType()) { - return success(); - } - return op->emitOpError("tensor descriptor block and tensor types must match"); -} - -LogicalResult DescriptorLoadOp::verify() -{ - return verifyDescriptorLoadStoreType(*this, getDesc().getType(), getType()); -} - -// -- DescriptorStoreOp -- -LogicalResult DescriptorStoreOp::verify() -{ - return verifyDescriptorLoadStoreType(*this, getDesc().getType(), - getSrc().getType()); -} - -// The following ops, including `call`, `func`, and `return` are copied and -// modified from -// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Func/IR/FuncOps.cpp -// We could revert it back once MLIR has a better inliner interface. 
-//-- FuncOp --
-void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
-                   FunctionType type, ArrayRef<NamedAttribute> attrs,
-                   ArrayRef<DictionaryAttr> argAttrs) {
-  state.addAttribute(SymbolTable::getSymbolAttrName(),
-                     builder.getStringAttr(name));
-  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
-  state.attributes.append(attrs.begin(), attrs.end());
-  state.addRegion();
-
-  if (argAttrs.empty())
-    return;
-  assert(type.getNumInputs() == argAttrs.size());
-#if LLVM_VERSION_MAJOR < 21
-  function_interface_impl::addArgAndResultAttrs(
-#else // triton_v3.3.x
-  call_interface_impl::addArgAndResultAttrs(
-#endif
-      builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
-      getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
-}
-
-ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
-  auto buildFuncType =
-      [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
-         function_interface_impl::VariadicFlag,
-         std::string &) { return builder.getFunctionType(argTypes, results); };
-
-  return function_interface_impl::parseFunctionOp(
-      parser, result, /*allowVariadic=*/false,
-      getFunctionTypeAttrName(result.name), buildFuncType,
-      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
-}
-
-void FuncOp::print(OpAsmPrinter &printer) {
-  function_interface_impl::printFunctionOp(
-      printer, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
-      getArgAttrsAttrName(), getResAttrsAttrName());
-}
-
-// -- CallOp --
-LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // Check that the callee attribute was specified.
-  auto fnAttr = (*this).getProperties().callee;
-  if (!fnAttr)
-    return emitOpError("requires a 'callee' symbol reference attribute");
-  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
-  if (!fn)
-    return emitOpError() << "'" << fnAttr.getValue()
-                         << "' does not reference a valid function";
-
-  // Verify that the operand and result types match the callee.
-  auto fnType = fn.getFunctionType();
-  if (fnType.getNumInputs() != getNumOperands())
-    return emitOpError("incorrect number of operands for callee");
-
-  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
-    if (getOperand(i).getType() != fnType.getInput(i))
-      return emitOpError("operand type mismatch: expected operand type ")
-             << fnType.getInput(i) << ", but provided "
-             << getOperand(i).getType() << " for operand number " << i;
-
-  if (fnType.getNumResults() != getNumResults())
-    return emitOpError("incorrect number of results for callee");
-
-  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
-    if (getResult(i).getType() != fnType.getResult(i)) {
-      auto diag = emitOpError("result type mismatch at index ") << i;
-      diag.attachNote() << " op result types: " << getResultTypes();
-      diag.attachNote() << "function result types: " << fnType.getResults();
-      return diag;
-    }
-
-  return success();
-}
-
-// -- ReturnOp --
-LogicalResult ReturnOp::verify() {
-  auto function = cast<FuncOp>((*this)->getParentOp());
-
-  // The operand number and types must match the function signature.
- const auto &results = function.getFunctionType().getResults(); - if (getNumOperands() != results.size()) - return emitOpError("has ") - << getNumOperands() << " operands, but enclosing function (@" - << function.getName() << ") returns " << results.size(); - - for (unsigned i = 0, e = results.size(); i != e; ++i) - if (getOperand(i).getType() != results[i]) - return emitError() << "type of return operand " << i << " (" - << getOperand(i).getType() - << ") doesn't match function result type (" - << results[i] << ")" - << " in function @" << function.getName(); - - return success(); -} - -// -- JoinOp -- -LogicalResult -JoinOp::inferReturnTypes(MLIRContext *context, std::optional location, - ValueRange operands, DictionaryAttr attributes, - OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - // These should have been checked by tablegen-generated code. - assert(operands.size() == 2); - assert(operands[0].getType() == operands[1].getType()); - assert(isa(operands[0].getType())); - assert(isa(operands[1].getType())); - - Value lhs = operands[0]; - auto srcTy = cast(lhs.getType()); - - SmallVector retShape(srcTy.getShape()); - retShape.push_back(2); - - Attribute srcEnc = srcTy.getEncoding(); - Attribute retEnc; - if (srcEnc) { - if (dyn_cast(&srcEnc.getDialect()) - ->inferJoinOpEncoding(srcEnc, retEnc, location) - .failed()) { - return failure(); - } - } - inferredReturnTypes.push_back( - RankedTensorType::get(retShape, srcTy.getElementType(), retEnc)); - return success(); -} - -// -- SplitOp -- -LogicalResult SplitOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - // These should have been checked by tablegen-generated code. - assert(operands.size() == 1); - assert(isa(operands[0].getType())); - - Value src = operands[0]; - auto srcTy = cast(src.getType()); - auto srcShape = srcTy.getShape(); - - if (srcShape.empty() || srcShape.back() != 2) { - return emitOptionalError(location, - "last dimension of input tensor must be 2"); - } - ArrayRef retShape(srcShape.begin(), srcShape.end() - 1); - - Attribute srcEnc = srcTy.getEncoding(); - Attribute retEnc; - if (srcEnc) { - if (dyn_cast(&srcEnc.getDialect()) - ->inferSplitOpEncoding(srcEnc, retEnc, location) - .failed()) { - return failure(); - } - } - auto retTy = RankedTensorType::get(retShape, srcTy.getElementType(), retEnc); - inferredReturnTypes.push_back(retTy); - inferredReturnTypes.push_back(retTy); - return success(); -} - -// -- ElementwiseInlineAsmOp -- -void ElementwiseInlineAsmOp::getEffects( - SmallVectorImpl> - &effects) { - if (getPure()) - return; - effects.emplace_back(MemoryEffects::Write::get(), - SideEffects::DefaultResource::get()); - effects.emplace_back(MemoryEffects::Read::get(), - SideEffects::DefaultResource::get()); -} - -LogicalResult ElementwiseInlineAsmOp::verify() { - if (getNumOperands() >= 1) { - auto tensorType = dyn_cast(getOperand(0).getType()); - size_t numInputElems = tensorType ? 
tensorType.getNumElements() : 0; - if (numInputElems % this->getPackedElement() != 0) { - return emitError("number of input elements ") - << numInputElems - << " must be a multiple of the op's packed_element attribute, " - << getPackedElement(); - } - } - return success(); -} - -// -- ExternElementwiseOp -- -void ExternElementwiseOp::getEffects( - SmallVectorImpl> - &effects) { - if (getPure()) - return; - effects.emplace_back(MemoryEffects::Write::get(), - SideEffects::DefaultResource::get()); - effects.emplace_back(MemoryEffects::Read::get(), - SideEffects::DefaultResource::get()); -} - -Speculation::Speculatability ExternElementwiseOp::getSpeculatability() { - if (getPure()) - return Speculation::Speculatable; - return Speculation::NotSpeculatable; -} - -// -- ExperimentalTensormapCreateOp -- -LogicalResult ExperimentalTensormapCreateOp::verify() { - auto rank = getBoxDim().size(); - if (getGlobalDim().size() != rank) { - return emitError("Rank mismatch for global dim. Got") - << getGlobalDim().size() << " but expected " << rank; - } - if (getGlobalStride().size() + 1 != rank) { - return emitError("Rank mismatch for global stride. Got") - << getGlobalStride().size() << " but expected " << rank - 1; - } - if (getElementStride().size() != rank) { - return emitError("Rank mismatch for element stride. Got") - << getElementStride().size() << " but expected " << rank; - } - return success(); -} - -// -- GatherOp -- -LogicalResult GatherOp::verify() { - RankedTensorType indicesTy = getIndices().getType(); - RankedTensorType srcTy = getSrc().getType(); - RankedTensorType resTy = getResult().getType(); - - if (indicesTy.getShape() != resTy.getShape()) { - return emitOpError("indices and output shapes must match"); - } - if (indicesTy.getEncoding() != resTy.getEncoding()) { - return emitOpError("indices and output encodings must match"); - } - if (srcTy.getElementType() != resTy.getElementType()) { - return emitOpError("input and output element types must match"); - } - if (srcTy.getRank() != indicesTy.getRank()) { - return emitOpError("input and indices ranks must match"); - } - if (getAxis() >= srcTy.getRank()) { - return emitOpError("gather dimension must be less than the input rank"); - } - for (int dim = 0; dim < indicesTy.getRank(); ++dim) { - if (dim == getAxis()) - continue; - if (indicesTy.getShape()[dim] != srcTy.getShape()[dim]) { - return emitOpError("indices dimension ") - << dim << " must match the corresponding input dimension"; - } - } - - return success(); -} - -LogicalResult GatherOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - GatherOpAdaptor adaptor(operands, attributes, properties, regions); - auto indicesType = cast(adaptor.getIndices().getType()); - auto srcType = cast(adaptor.getSrc().getType()); - - // Shape and encoding of the indices with the element type of the src. 
- inferredReturnTypes.push_back( - RankedTensorType::get(indicesType.getShape(), srcType.getElementType(), - indicesType.getEncoding())); - return success(); -} - -} // namespace triton -} // namespace mlir diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Traits.cpp b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Traits.cpp deleted file mode 100644 index b43a9b56c..000000000 --- a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Traits.cpp +++ /dev/null @@ -1,239 +0,0 @@ -#include "triton/Dialect/Triton/IR/Traits.h" - -#include - -#include "mlir/IR/TypeUtilities.h" -#include "triton/Dialect/Triton/IR/Types.h" -#include "triton/Dialect/Triton/IR/Utility.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "llvm/Support/ErrorHandling.h" - -using namespace mlir; -namespace ttg = mlir::triton::gpu; - -static LogicalResult verifySameEncoding(Type typeA, Type typeB, - bool allowTensorPointerType) { - // TODO(Keren): the allowTensorPointerType argument is a hack to allow. - // The type checking code is kind of a mess with the current design. - auto getEncoding = [=](Type type) -> Attribute { - Attribute ret; - if (auto tensorType = dyn_cast(type)) { - ret = tensorType.getEncoding(); - } - if (!allowTensorPointerType) { - assert(!triton::isTensorPointerType(type)); - } - return ret; - }; - auto encodingA = getEncoding(typeA); - auto encodingB = getEncoding(typeB); - if (!encodingA || !encodingB) - return success(); - return encodingA == encodingB ? success() : failure(); -} - -LogicalResult -OpTrait::impl::verifySameOperandsEncoding(Operation *op, - bool allowTensorPointerType) { - if (failed(verifyAtLeastNOperands(op, 1))) - return failure(); - - auto type = op->getOperand(0).getType(); - for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) - if (failed(verifySameEncoding(opType, type, allowTensorPointerType))) - return op->emitOpError() << "requires the same encoding for all operands"; - - return success(); -} - -LogicalResult OpTrait::impl::verifySameOperandsAndResultEncoding( - Operation *op, bool allowTensorPointerType) { - if (op->getNumOperands() == 0) - return success(); - - if (failed(verifyAtLeastNOperands(op, 1)) || - failed(verifyAtLeastNResults(op, 1))) - return failure(); - - auto type = op->getOperand(0).getType(); - for (auto resultType : op->getResultTypes()) - if (failed(verifySameEncoding(resultType, type, allowTensorPointerType))) - return op->emitOpError() - << "requires the same encoding for all operands and results"; - - return verifySameOperandsEncoding(op, allowTensorPointerType); -} - -LogicalResult OpTrait::impl::verifyTensorSize(Operation *op) { - for (auto opType : op->getOperandTypes()) { - if (auto tensorType = dyn_cast(opType)) { - int64_t numElements = 1; - for (int64_t s : tensorType.getShape()) - numElements *= s; - if (numElements > maxTensorNumElements) - return op->emitError("Maximum allowed number of elements is ") - << maxTensorNumElements << ", but " << *op - << " has more than that"; - // if ((numElements & (numElements - 1)) != 0) - // return op->emitError("Number of elements must be power-of-two, but ") - // << *op << " doesn't follow the rule (" << numElements << ")" - // << " elements"; - } - } - for (auto opType : op->getResultTypes()) { - if (auto tensorType = dyn_cast(opType)) { - int64_t numElements = 1; - for (int64_t s : tensorType.getShape()) - numElements *= s; - if (numElements > maxTensorNumElements) - return op->emitError("Maximum allowed number of elements is ") - << maxTensorNumElements << 
", but " << *op - << " has more than that"; - // if ((numElements & (numElements - 1)) != 0) - // return op->emitError("Number of elements must be power-of-two, but ") - // << *op << " doesn't follow the rule (" << numElements << ")" - // << " elements"; - } - } - return success(); -} - -// Check that the Triton layouts on op's operands and return types are valid. -// For example, we check that the number of warps per block in a Triton GPU -// blocked layout matches that of its module. -// -// It's a little weird to check these properties of a layout only when the -// layout is used in an op, since most of the properties don't actually depend -// on the op. They do depend on the *module*, though, and a layout is attached -// to a module only by virtue of being used in one of the module's ops. -LogicalResult OpTrait::impl::verifyTensorLayouts(Operation *op) { - auto module = op->getParentOfType(); - auto checkLayout = [&](Value val, auto makeErr) -> LogicalResult { - // Only ranked tensors can have layouts. - auto rankedTy = dyn_cast(val.getType()); - if (!rankedTy) - return success(); - - mlir::Attribute layout = rankedTy.getEncoding(); - if (!layout) - return success(); - - if (isa(layout)) - return makeErr() << "Shared layout is not allowed on tensor type."; - // TODO(jlebar): Currently this only checks blocked layouts, but other - // layouts also have invariants! - - // TODO(jlebar): Handle the case when the encoding is nested within tt.ptr. - if (auto blocked = dyn_cast(layout)) { - // A different verifier should have checked that the layout itself is - // valid, including that threads-per-warp has the same rank as - // warps-per-block etc. - auto layoutRank = blocked.getThreadsPerWarp().size(); - if (layoutRank != rankedTy.getRank()) { - return makeErr() << layout << ".\nLayout has rank " << layoutRank - << ", but the tensor it's attached to has rank " - << rankedTy.getRank() << "."; - } - - int moduleThreadsPerWarp = - ttg::TritonGPUDialect::getThreadsPerWarp(module); - int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp()); - if (layoutThreadsPerWarp != moduleThreadsPerWarp) { - return makeErr() << layout << ".\nLayout has a total of " - << layoutThreadsPerWarp - << " threads per warp, but the module specifies " - << moduleThreadsPerWarp << " threads per warp."; - } - - int moduleWarpsPerCTA = ttg::TritonGPUDialect::getNumWarps(module); - int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA()); - if (layoutWarpsPerCTA != moduleWarpsPerCTA) { - return makeErr() << layout << ".\nLayout has a total of " - << layoutWarpsPerCTA - << " warps per CTA, but the module specifies " - << moduleWarpsPerCTA << " warps per CTA."; - } - - if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) { - int moduleCTAsPerCGA = ttg::TritonGPUDialect::getNumCTAs(module); - int64_t layoutCTAsPerCGA = - product(blocked.getCTALayout().getCTAsPerCGA()); - if (layoutCTAsPerCGA != moduleCTAsPerCGA) { - return makeErr() << layout << ".\nLayout has a total of " - << layoutCTAsPerCGA - << " CTAs per CGA, but the module specifies " - << moduleCTAsPerCGA << " CTAs per CGA."; - } - } - } - - return success(); - }; - - for (size_t i = 0; i < op->getNumOperands(); i++) { - auto operand = op->getOperand(i); - auto err = checkLayout(operand, [&]() { - // Stringify the operand using `printAsOperand`. This prints e.g. "%42" - // rather than the full definition. 
- std::string operandStr; - llvm::raw_string_ostream os(operandStr); - // If we don't assume verified, dump() will recursively call this - // function! - operand.printAsOperand(os, OpPrintingFlags().assumeVerified()); - - return op->emitError("Operand ") - << i << " (" << operand << ") has an invalid layout: "; - }); - if (!err.succeeded()) - return err; - } - - for (size_t i = 0; i < op->getNumResults(); i++) { - auto result = op->getResult(i); - auto err = checkLayout(result, [&]() { - if (op->getNumResults() == 1) { - return op->emitError("Result has an invalid layout: "); - } else { - return op->emitError("Result ") << i << " has an invalid layout: "; - } - }); - if (!err.succeeded()) - return err; - } - - return success(); -} - -static ArrayRef getTypeShape(Type type) { - auto rankedType = dyn_cast(type); - if (auto ptrType = dyn_cast(type)) - rankedType = dyn_cast(ptrType.getPointeeType()); - return rankedType ? rankedType.getShape() : ArrayRef(); -} - -LogicalResult OpTrait::impl::verifySameLoadStoreOperandsShape(Operation *op) { - if (failed(verifyAtLeastNOperands(op, 1))) - return failure(); - - auto firstOperandShape = getTypeShape(op->getOperand(0).getType()); - for (auto type : llvm::drop_begin(op->getOperandTypes(), 1)) - if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape))) - return op->emitOpError() << "requires the same shape for all operands"; - - return success(); -} - -LogicalResult -OpTrait::impl::verifySameLoadStoreOperandsAndResultShape(Operation *op) { - if (failed(verifyAtLeastNOperands(op, 1)) || - failed(verifyAtLeastNResults(op, 1))) - return failure(); - - auto firstOperandShape = getTypeShape(op->getOperand(0).getType()); - for (auto type : op->getResultTypes()) - if (failed(verifyCompatibleShape(getTypeShape(type), firstOperandShape))) - return op->emitOpError() - << "requires the same shape for all operands and results"; - - return verifySameLoadStoreOperandsShape(op); -} diff --git a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Types.cpp b/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Types.cpp deleted file mode 100644 index 6e41e70a8..000000000 --- a/third_party/ascend/triton_patch/lib/Dialect/Triton/IR/Types.cpp +++ /dev/null @@ -1,197 +0,0 @@ -#include "triton/Dialect/Triton/IR/Types.h" - -#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Support/LLVM.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` - -using namespace mlir; -using namespace mlir::triton; - -#define GET_TYPEDEF_CLASSES -#include "triton/Dialect/Triton/IR/Types.cpp.inc" - -//===----------------------------------------------------------------------===// -// Triton Dialect -//===----------------------------------------------------------------------===// -void TritonDialect::registerTypes() { - addTypes< -#define GET_TYPEDEF_LIST -#include "triton/Dialect/Triton/IR/Types.cpp.inc" - >(); -} - -Type PointerType::parse(AsmParser &parser) { - if (parser.parseLess()) - return Type(); - - Type pointeeType; - if (parser.parseType(pointeeType)) - return Type(); - - int addressSpace = 1; - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseInteger(addressSpace)) - return Type(); - } - - if (parser.parseGreater()) - return Type(); - - return PointerType::get(pointeeType, addressSpace); -} - -void PointerType::print(AsmPrinter &printer) const { - if (getAddressSpace() == 1) { - printer << "<" << getPointeeType() << 
">"; - } else { - printer << "<" << getPointeeType() << ", " << getAddressSpace() << ">"; - } -} - -static constexpr llvm::StringRef kMutableMemory = "mutable"; - -Type MemDescType::parse(AsmParser &parser) { - if (parser.parseLess()) - return Type(); - - SmallVector dimensions; - if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false)) - return Type(); - - // Parse the element type. - Type elementType; - if (parser.parseType(elementType)) - return Type(); - - Attribute encoding; - if (succeeded(parser.parseOptionalComma())) { - if (parser.parseAttribute(encoding)) - return Type(); - } - bool mutableMemory = false; - Attribute memorySpace; - if (succeeded(parser.parseOptionalComma())) { - if (failed(parser.parseOptionalKeyword(kMutableMemory))) { - if (parser.parseAttribute(memorySpace)) - return Type(); - } else { - mutableMemory = true; - } - } - if (mutableMemory == false && succeeded(parser.parseOptionalComma())) { - if (parser.parseOptionalKeyword(kMutableMemory)) - return Type(); - mutableMemory = true; - } - if (parser.parseGreater()) - return Type(); - return MemDescType::get(parser.getContext(), dimensions, elementType, - encoding, memorySpace, mutableMemory); -} - -void MemDescType::print(AsmPrinter &printer) const { - printer << "<"; - for (auto dim : getShape()) - printer << dim << "x"; - printer << getElementType(); - if (getEncoding()) - printer << ", " << getEncoding(); - if (getMemorySpace()) - printer << ", " << getMemorySpace(); - if (getMutableMemory()) - printer << ", " << kMutableMemory; - printer << ">"; -} - -namespace mlir { - -namespace triton { - -unsigned getPointeeBitWidth(Type type) { - auto pointeeType = getPointeeType(type); - if (auto tensorTy = dyn_cast(pointeeType)) - return tensorTy.getElementType().getIntOrFloatBitWidth(); - return pointeeType.getIntOrFloatBitWidth(); -} - -Type getI1SameShape(Type type) { - auto i1Type = IntegerType::get(type.getContext(), 1); - if (auto tensorTy = dyn_cast(type)) - return RankedTensorType::get(tensorTy.getShape(), i1Type, - tensorTy.getEncoding()); - return i1Type; -} - -Type getPointeeType(Type type) { - if (auto tensorTy = dyn_cast(type)) { - // Tensor of pointers - auto shape = tensorTy.getShape(); - auto ptrType = dyn_cast(tensorTy.getElementType()); - Type pointeeType = ptrType.getPointeeType(); - return RankedTensorType::get(shape, pointeeType, tensorTy.getEncoding()); - } else if (auto ptrType = dyn_cast(type)) { - // scalar pointer - Type pointeeType = ptrType.getPointeeType(); - return pointeeType; - } - return type; -} - -Type getI32SameShape(Type type) { - auto i32Type = IntegerType::get(type.getContext(), 32); - if (auto tensorTy = dyn_cast(type)) - return RankedTensorType::get(tensorTy.getShape(), i32Type, - tensorTy.getEncoding()); - return i32Type; -} - -Type getPointerTypeSameShape(Type type) { - if (auto tensorTy = dyn_cast(type)) { - Type elementType = tensorTy.getElementType(); - auto shape = tensorTy.getShape(); - PointerType ptrType = PointerType::get(elementType, 1); - return RankedTensorType::get(shape, ptrType, tensorTy.getEncoding()); - } else { - return PointerType::get(type, 1); - } -} - -Type getPointerTypeToElement(Type type) { - Type elementType = getElementTypeOrSelf(type); - PointerType ptrType = PointerType::get(elementType, 1); - return ptrType; -} - -// upstream Triton only uses address space 1 for Pointer Type -Type getPointerType(Type type, int addressSpace) { - return PointerType::get(type, addressSpace); -} - -int getAddressSpace(Type type) { - if (auto ptrType = 
dyn_cast<PointerType>(type))
-    return ptrType.getAddressSpace();
-  return 1;
-}
-
-bool isTensorPointerType(Type type) {
-  if (auto ptrType = dyn_cast<PointerType>(type))
-    return isa<RankedTensorType>(ptrType.getPointeeType());
-  return false;
-}
-
-bool isTensorOrTensorPointerType(Type type) {
-  return isa<RankedTensorType>(type) || isTensorPointerType(type);
-}
-
-Type getElementTypeOfTensorPointerType(Type type) {
-  if (auto ptrType = dyn_cast<PointerType>(type))
-    if (auto tensorTy = dyn_cast<RankedTensorType>(ptrType.getPointeeType()))
-      return tensorTy.getElementType();
-  return {};
-}
-
-} // namespace triton
-
-} // namespace mlir