Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 21 additions & 56 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,21 @@ if(FLAGTREE_PLUGIN)
add_definitions(-D__FLAGTREE_PLUGIN__)
endif()

# get_flagtree_backend_lib(<lib_name> <output_lib>)
#
# Looks up the backend-specific FlagTree library target for <lib_name>.
# The candidate target name is FlagTree_${FLAGTREE_BACKEND}_<lib_name>.
#
# Arguments:
#   lib_name   - suffix of the target name to look up.
#   output_lib - name of the variable, in the caller's scope, that receives
#                the result.
#
# Result:
#   <output_lib> is set to the candidate target name when a target with that
#   name already exists at the time of the call, and to the empty string
#   otherwise.
#
# Example:
#   get_flagtree_backend_lib(TritonDialectTritonIR extra_link_libs)
function(get_flagtree_backend_lib lib_name output_lib)
  set(candidate "FlagTree_${FLAGTREE_BACKEND}_${lib_name}")
  if(TARGET "${candidate}")
    set(${output_lib} "${candidate}" PARENT_SCOPE)
  else()
    # Pass an explicit quoted empty value: set(<var> PARENT_SCOPE) with zero
    # value arguments would *unset* the caller's variable instead of making
    # it an empty string.
    set(${output_lib} "" PARENT_SCOPE)
  endif()
endfunction()

project(triton)
include(CTest)

# Ascend backend: record where the patched Triton tree lives, which library
# and dependency names have patched counterparts, and put the patched headers
# (source tree and generated) on the include path.
if(FLAGTREE_BACKEND STREQUAL "ascend")
  set(TRITON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
  set(PATCHED_TRITON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/ascend/triton_patch")

  # Names that are swapped for their patched variants.
  set(PATCHED_TRITON_LIBRARIES "TritonIR")
  set(PATCHED_TRITON_DEPENDS "TritonTableGen")

  include_directories(
    ${PATCHED_TRITON_ROOT_DIR}/include
    ${PROJECT_BINARY_DIR}/third_party/ascend/triton_patch/include # Tablegen'd files
  )
endif()

if(NOT WIN32)
Expand Down Expand Up @@ -99,6 +100,12 @@ if(TRITON_BUILD_UT)
endif()

# Compiler flags
## flagtree spec include dir
# Directory holding the selected backend's specialization headers. The
# variable is also read by include/triton/Dialect/Triton/IR/CMakeLists.txt to
# locate backend-provided .td files, so it is set even when the directory does
# not exist; only the include path registration is conditional.
set(BACKEND_SPEC_INCLUDE_DIR "${PROJECT_SOURCE_DIR}/third_party/${FLAGTREE_BACKEND}/backend/flagtree_backend_specialization/include")
# Paths are quoted (matching the third_party include check that follows) so a
# source path containing ';' is not split into several arguments.
if(FLAGTREE_BACKEND AND EXISTS "${BACKEND_SPEC_INCLUDE_DIR}")
  include_directories("${BACKEND_SPEC_INCLUDE_DIR}")
endif()
## flagtree third_party include dir
set(BACKEND_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/${FLAGTREE_BACKEND}/include)
if(FLAGTREE_BACKEND AND EXISTS "${BACKEND_INCLUDE_DIR}")
include_directories(${BACKEND_INCLUDE_DIR})
Expand Down Expand Up @@ -134,54 +141,17 @@ function(add_triton_object name)
INTERFACE $<TARGET_OBJECTS:${name}>
)

if (FLAGTREE_BACKEND STREQUAL "ascend")
set(patched_depends "")
foreach(dep ${ARG_DEPENDS})
list(FIND PATCHED_TRITON_DEPENDS "${dep}" index)
if(index GREATER_EQUAL 0)
list(APPEND patched_depends "Patched_${dep}")
message(STATUS "Replace ${dep} by Patched_${dep} as a dependent of ${name}")
else()
list(APPEND patched_depends ${dep})
endif()
endforeach()
if(patched_depends)
add_dependencies(${name} ${patched_depends})
endif()

set(patched_link_libs "")
foreach(lib ${ARG_LINK_LIBS})
list(FIND PATCHED_TRITON_LIBRARIES "${lib}" index)
if(index GREATER_EQUAL 0)
list(APPEND patched_link_libs "Patched_${lib}")
message(STATUS "Replace ${lib} by Patched_${lib} to be linked by ${name}")
else()
list(APPEND patched_link_libs ${lib})
endif()
endforeach()
if(patched_link_libs)
target_link_libraries(${name} PUBLIC ${patched_link_libs})
endif()
else()
#add_library(${name} OBJECT ${ARG_UNPARSED_ARGUMENTS})
if(ARG_DEPENDS)
add_dependencies(${name} ${ARG_DEPENDS})
endif()
if(ARG_LINK_LIBS)
target_link_libraries(${name} PUBLIC ${ARG_LINK_LIBS})
endif()
#add_library(${name} OBJECT ${ARG_UNPARSED_ARGUMENTS})
if(ARG_DEPENDS)
add_dependencies(${name} ${ARG_DEPENDS})
endif()
if(ARG_LINK_LIBS)
target_link_libraries(${name} PUBLIC ${ARG_LINK_LIBS})
endif()
endfunction(add_triton_object)

set_property(GLOBAL PROPERTY TRITON_LIBS "")
function(add_triton_library name)
if (FLAGTREE_BACKEND STREQUAL "ascend")
list(FIND PATCHED_TRITON_LIBRARIES "${name}" index)
if(index GREATER_EQUAL 0)
message(STATUS "Adding Patched_${name} as a lib, instead of ${name}")
return()
endif()
endif()
set_property(GLOBAL APPEND PROPERTY TRITON_LIBS ${name})
add_triton_object(${name} ${ARGN})
llvm_update_compile_flags(${name})
Expand Down Expand Up @@ -230,11 +200,6 @@ elseif(NOT FLAGTREE_BACKEND)
add_subdirectory(lib)
endif()

# Ascend only: additionally build the patched Triton tree rooted at
# PATCHED_TRITON_ROOT_DIR, adding its include/ directory first and its lib/
# directory second.
# NOTE(review): presumably include/ must be added before lib/ so that
# generated-header targets exist when the libraries are declared -- confirm.
if (FLAGTREE_BACKEND STREQUAL "ascend")
add_subdirectory(${PATCHED_TRITON_ROOT_DIR}/include)
add_subdirectory(${PATCHED_TRITON_ROOT_DIR}/lib)
endif()

# find_package(PythonLibs REQUIRED)
set(TRITON_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
Expand Down
15 changes: 15 additions & 0 deletions include/triton/Dialect/Triton/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})

# Pick the TableGen input for the op definitions: a backend-provided
# TritonOps.td under BACKEND_SPEC_INCLUDE_DIR wins when it exists, otherwise
# the TritonOps.td in this directory is used.
set(BACKEND_SPEC_TD ${BACKEND_SPEC_INCLUDE_DIR}/triton/Dialect/Triton/IR/TritonOps.td)
if(EXISTS ${BACKEND_SPEC_TD})
  set(LLVM_TARGET_DEFINITIONS ${BACKEND_SPEC_TD})
else()
  set(LLVM_TARGET_DEFINITIONS TritonOps.td)
endif()
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
Expand All @@ -13,6 +17,10 @@ mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
add_mlir_doc(TritonDialect TritonDialect dialects/ -gen-dialect-doc)

# Pick the TableGen input for the type definitions: a backend-provided
# TritonTypes.td under BACKEND_SPEC_INCLUDE_DIR wins when it exists, otherwise
# the TritonTypes.td in this directory is used.
set(BACKEND_SPEC_TD ${BACKEND_SPEC_INCLUDE_DIR}/triton/Dialect/Triton/IR/TritonTypes.td)
if(EXISTS ${BACKEND_SPEC_TD})
  set(LLVM_TARGET_DEFINITIONS ${BACKEND_SPEC_TD})
else()
  set(LLVM_TARGET_DEFINITIONS TritonTypes.td)
endif()
mlir_tablegen(Types.h.inc -gen-typedef-decls)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs)

Expand All @@ -24,4 +32,11 @@ set(LLVM_TARGET_DEFINITIONS TritonTypeInterfaces.td)
mlir_tablegen(TritonTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(TritonTypeInterfaces.cpp.inc -gen-type-interface-defs)

# Optional backend-provided op interfaces. There is no fallback .td in this
# directory: OpInterfaces.h.inc / OpInterfaces.cpp.inc are generated only when
# the backend spec tree contains TritonOpInterfaces.td. When it does not,
# LLVM_TARGET_DEFINITIONS is left at its previous value and nothing is emitted.
set(BACKEND_SPEC_TD ${BACKEND_SPEC_INCLUDE_DIR}/triton/Dialect/Triton/IR/TritonOpInterfaces.td)
if(EXISTS ${BACKEND_SPEC_TD})
set(LLVM_TARGET_DEFINITIONS ${BACKEND_SPEC_TD})
mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs)
endif()

add_public_tablegen_target(TritonTableGen)
97 changes: 51 additions & 46 deletions include/triton/Dialect/Triton/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,64 +17,69 @@
#include "triton/Dialect/Triton/IR/Traits.h"
#include "triton/Dialect/Triton/IR/Types.h"

#include "flagtree_spec.h"
#ifdef FLAGTREE_SPEC_Dialect_Triton_IR_OpInterfaces_head
#include "triton/Dialect/Triton/IR/OpInterfaces.h"
#endif

#define GET_OP_CLASSES
#include "triton/Dialect/Triton/IR/Ops.h.inc"

namespace mlir {
namespace triton {

// Side-effect resource tag type. Its only member reports the resource's
// display name, "<GlobalMemory>".
struct GlobalMemory : public SideEffects::Resource::Base<GlobalMemory> {
  // Defined exactly once; a second identical definition in the same class
  // body is a redefinition error.
  StringRef getName() final { return "<GlobalMemory>"; }
};

// Dialect interface declaring the layout-encoding hooks listed below. Every
// hook is pure virtual, so a dialect that registers this interface must
// implement all of them. Each hook returns a LogicalResult; hooks that take
// an `Attribute &` parameter report their answer through that reference.
//
// Each member is declared exactly once: repeating an identical member
// declaration inside the class body is ill-formed.
class DialectInferLayoutInterface
    : public DialectInterface::Base<DialectInferLayoutInterface> {
public:
  DialectInferLayoutInterface(Dialect *dialect) : Base(dialect) {}

  virtual LogicalResult
  inferTransOpEncoding(Attribute operandEncoding, ArrayRef<int32_t> order,
                       Attribute &resultEncoding) const = 0;

  virtual LogicalResult
  inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
                        Attribute &resultEncoding) const = 0;

  virtual LogicalResult
  inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
                            Attribute &resultEncoding,
                            std::optional<Location> location) const = 0;

  // Note: This function only verifies the operand encoding. It doesn't infer
  // the result encoding.
  virtual LogicalResult
  inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx,
                     Attribute retEncoding,
                     std::optional<Location> location) const = 0;

  // Tries to compute the encoding for the result of a reshape operation that
  // makes the reshape a "nop", i.e. the same GPU threads contain the same
  // elements as before the reshape. Note that this is not always possible (in
  // which case you'd need to choose a different layout for the input to the
  // reshape).
  virtual LogicalResult
  inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
                                  ArrayRef<int64_t> dstShape, Attribute &dstEnc,
                                  std::optional<Location> loc) const = 0;

  virtual LogicalResult
  inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
                      std::optional<Location> loc) const = 0;

  virtual LogicalResult
  inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc,
                       std::optional<Location> loc) const = 0;

  // Verify that the encodings are compatible to be used together in a dot
  // operation.
  virtual LogicalResult
  verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA,
                                   Attribute operandEncodingB) const = 0;
};

} // namespace triton
Expand Down
7 changes: 7 additions & 0 deletions lib/Dialect/Triton/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Resolve the optional backend-specific companion library for TritonIR:
# _EXTRA_LINK_LIBS holds the target name when that target has already been
# defined, and is empty otherwise.
if(TARGET FlagTree_${FLAGTREE_BACKEND}_TritonDialectTritonIR)
  set(_EXTRA_LINK_LIBS FlagTree_${FLAGTREE_BACKEND}_TritonDialectTritonIR)
else()
  set(_EXTRA_LINK_LIBS "")
endif()

add_triton_library(TritonIR
Dialect.cpp
Ops.cpp
Expand All @@ -12,4 +17,6 @@ add_triton_library(TritonIR
MLIRArithDialect
MLIRMathDialect
MLIRSCFDialect
FlagTree_${FLAGTREE_BACKEND}_TritonDialectTritonIR
# TODO: ${_EXTRA_LINK_LIBS} is error!
)
4 changes: 4 additions & 0 deletions lib/Dialect/Triton/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
#include "triton/Dialect/Triton/IR/Dialect.cpp.inc"
#include "triton/Dialect/Triton/IR/TritonTypeInterfaces.cpp.inc"

#ifdef FLAGTREE_SPEC_Dialect_Triton_IR_OpInterfaces_inc
#include "triton/Dialect/Triton/IR/OpInterfaces.cpp.inc"
#endif

using namespace mlir;
using namespace mlir::triton;

Expand Down
6 changes: 6 additions & 0 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Utility.h"

#include "flagtree_spec.h"

namespace mlir {
namespace triton {

Expand Down Expand Up @@ -818,6 +820,7 @@ OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) {
// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Func/IR/FuncOps.cpp
// We could revert it back once MLIR has a better inliner interface.
//-- FuncOp --
#ifndef FLAGTREE_SPEC_Dialect_Triton_IR_Ops_build
void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
FunctionType type, ArrayRef<NamedAttribute> attrs,
ArrayRef<DictionaryAttr> argAttrs) {
Expand All @@ -834,6 +837,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
}
#endif

ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
auto buildFuncType =
Expand Down Expand Up @@ -912,6 +916,7 @@ LogicalResult ReturnOp::verify() {
}

// -- JoinOp --
#ifndef FLAGTREE_SPEC_Dialect_Triton_IR_Ops_inferReturnTypes
LogicalResult
JoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
ValueRange operands, DictionaryAttr attributes,
Expand Down Expand Up @@ -943,6 +948,7 @@ JoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
RankedTensorType::get(retShape, srcTy.getElementType(), retEnc));
return success();
}
#endif

// -- SplitOp --
LogicalResult SplitOp::inferReturnTypes(
Expand Down
4 changes: 4 additions & 0 deletions lib/Dialect/Triton/IR/Traits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/Support/ErrorHandling.h"

#include "flagtree_spec.h"

using namespace mlir;
namespace ttg = mlir::triton::gpu;

Expand Down Expand Up @@ -64,6 +66,7 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultEncoding(
return verifySameOperandsEncoding(op, allowTensorPointerType);
}

#ifndef FLAGTREE_SPEC_Dialect_Triton_IR_TRAITS_verifyTensorSize
LogicalResult OpTrait::impl::verifyTensorSize(Operation *op) {
for (auto opType : op->getOperandTypes()) {
if (auto tensorType = dyn_cast<RankedTensorType>(opType)) {
Expand Down Expand Up @@ -97,6 +100,7 @@ LogicalResult OpTrait::impl::verifyTensorSize(Operation *op) {
}
return success();
}
#endif

// Check that the Triton layouts on op's operands and return types are valid.
// For example, we check that the number of warps per block in a Triton GPU
Expand Down
2 changes: 2 additions & 0 deletions third_party/ascend/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#add_subdirectory(triton-adapter triton-adapter)
#add_subdirectory(test)

# Build the Ascend backend specialization sources that live under
# backend/flagtree_backend_specialization/lib.
add_subdirectory(backend/flagtree_backend_specialization/lib)

# Declare the TritonAscend plugin from triton_ascend.cpp and give it private
# access to the flir headers under the top-level source tree.
# NOTE(review): add_triton_plugin is defined outside this file; its exact
# behavior is not visible here.
add_triton_plugin(TritonAscend ${CMAKE_CURRENT_SOURCE_DIR}/triton_ascend.cpp)
target_include_directories(TritonAscend PRIVATE ${CMAKE_SOURCE_DIR}/third_party/flir/include)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// Umbrella header for the Ascend backend specialization: pulls in the
// per-file ascend_*.h headers for Dialect, Ops, and Traits. Sources that
// include "flagtree_spec.h" see whatever those headers provide (the ones
// visible in this change define FLAGTREE_SPEC_* switch macros).
#ifndef ASCEND_FLAGTREE_SPEC_H_
#define ASCEND_FLAGTREE_SPEC_H_

#include "triton/Dialect/Triton/IR/ascend_Dialect.h"
#include "triton/Dialect/Triton/IR/ascend_Ops.h"
#include "triton/Dialect/Triton/IR/ascend_Traits.h"

#endif // ASCEND_FLAGTREE_SPEC_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// TODO: When upgrading to Triton 3.4.0, remove this file and use the upstream Triton file.
#ifndef ASCEND_TRITON_DIALECT_TRITON_IR_DIALECT_H_
#define ASCEND_TRITON_DIALECT_TRITON_IR_DIALECT_H_

// Switch macros consumed by the shared Triton dialect sources:
//  - ..._OpInterfaces_head: when defined, triton/Dialect/Triton/IR/Dialect.h
//    includes "triton/Dialect/Triton/IR/OpInterfaces.h".
//  - ..._OpInterfaces_inc: when defined, lib/Dialect/Triton/IR/Dialect.cpp
//    includes the generated "triton/Dialect/Triton/IR/OpInterfaces.cpp.inc".
#define FLAGTREE_SPEC_Dialect_Triton_IR_OpInterfaces_head
#define FLAGTREE_SPEC_Dialect_Triton_IR_OpInterfaces_inc

#endif // ASCEND_TRITON_DIALECT_TRITON_IR_DIALECT_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#ifndef ASCEND_TRITON_DIALECT_TRITON_IR_OPS_H_
#define ASCEND_TRITON_DIALECT_TRITON_IR_OPS_H_

// Switch macros consumed by lib/Dialect/Triton/IR/Ops.cpp, which wraps the
// corresponding definitions in #ifndef guards:
//  - ..._Ops_build: when defined, the shared FuncOp::build definition is
//    compiled out.
//  - ..._Ops_inferReturnTypes: when defined, the shared
//    JoinOp::inferReturnTypes definition is compiled out.
// NOTE(review): presumably the Ascend backend supplies replacement
// definitions elsewhere -- confirm.
#define FLAGTREE_SPEC_Dialect_Triton_IR_Ops_build
#define FLAGTREE_SPEC_Dialect_Triton_IR_Ops_inferReturnTypes

#endif // ASCEND_TRITON_DIALECT_TRITON_IR_OPS_H_
Loading