Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions tensorflow/core/util/gpu_kernel_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,17 @@ Status GpuLaunchKernel(void (*function)(Ts...), dim3 grid_dim, dim3 block_dim,
return errors::Internal(cudaGetErrorString(result));
}
#elif TENSORFLOW_USE_ROCM
hipLaunchKernelGGL(function, grid_dim, block_dim, shared_memory_size_bytes,
stream, std::forward<Args>(arguments)...);
TF_RETURN_IF_CUDA_ERROR(hipGetLastError());
constexpr size_t count = sizeof...(Args);
auto tup_ = std::tuple<Args...>{arguments...};
auto tup = validateArgsCountType(function, tup_);
void* _Args[count];
pArgs<0>(tup, _Args);
auto k = reinterpret_cast<void*>(function);
auto result =
hipLaunchKernel(k, grid_dim, block_dim, _Args, shared_memory_size_bytes, stream);
if (result != hipSuccess) {
return errors::Internal(hipGetErrorString(result));
}
#endif
}
return OkStatus();
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/workspace2.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ load("//third_party/ruy:workspace.bzl", ruy = "repo")
load("//third_party/sobol_data:workspace.bzl", sobol_data = "repo")
load("//third_party/systemlibs:syslibs_configure.bzl", "syslibs_configure")
load("//third_party/vulkan_headers:workspace.bzl", vulkan_headers = "repo")
load("@local_xla//third_party/rocm_device_libs:workspace.bzl", rocm_device_libs = "repo")

def _initialize_third_party():
""" Load third party repositories. See above load() statements. """
Expand All @@ -87,6 +88,7 @@ def _initialize_third_party():
ml_dtypes()
nanobind()
nasm()
rocm_device_libs()
opencl_headers()
pasta()
pybind11_abseil()
Expand Down

Large diffs are not rendered by default.

28 changes: 28 additions & 0 deletions third_party/llvm/workspace.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Provides the repository macro to import LLVM."""

load("//third_party:repo.bzl", "tf_http_archive")

def repo(name):
    """Imports LLVM.

    Args:
        name: Name of the created external repository (conventionally
            "llvm-raw" in this workspace).
    """

    # Pin llvm-project to one exact upstream commit; the sha256 guards the
    # integrity of the downloaded archive.
    LLVM_COMMIT = "f8287f6c373fcf993643dd6f0e30dde304c1be73"
    LLVM_SHA256 = "add2841174abc79c45aa309bdf0cf631aa8f97e7a4df57dcfca57c60df27527f"

    tf_http_archive(
        name = name,
        sha256 = LLVM_SHA256,
        strip_prefix = "llvm-project-{commit}".format(commit = LLVM_COMMIT),
        urls = [
            # TensorFlow mirror first; GitHub is the fallback source.
            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
            "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
        ],
        build_file = "//third_party/llvm:llvm.BUILD",
        # Local patches applied on top of the pinned commit.
        patch_file = [
            "//third_party/llvm:generated.patch",  # Autogenerated, don't remove.
            "//third_party/llvm:build.patch",
            "//third_party/llvm:mathextras.patch",
            "//third_party/llvm:toolchains.patch",
            "//third_party/llvm:zstd.patch",
            "//third_party/llvm:0001-clang-CodeGen-sret-args-should-always-point-to-the-a.patch",
        ],
        link_files = {"//third_party/llvm:run_lit.sh": "mlir/run_lit.sh"},
    )
1 change: 1 addition & 0 deletions third_party/xla/third_party/rocm_device_libs/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# copybara:uncomment package(default_applicable_licenses = ["//third_party/tensorflow:license"])
96 changes: 96 additions & 0 deletions third_party/xla/third_party/rocm_device_libs/build_defs.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
load("@bazel_skylib//lib:paths.bzl", "paths")

def bitcode_library(
        name,
        srcs = [],
        hdrs = [],
        file_specific_flags = {},
        # NOTE(review): **kwargs is accepted but never forwarded to any rule
        # below, so extra attributes are silently dropped — confirm intended.
        **kwargs
):
    """Builds a bitcode library

    Compiles each *.cl source to LLVM bitcode with clang, links the results
    (plus any raw *.ll inputs) with llvm-link, internalizes and strips the
    linked module, and finally post-processes it with prepare_builtins to
    produce a single `<name>.bc`.

    Args:
        name: Unique name of the build rule.
        srcs: List of source files (*.cl, *.ll).
        hdrs: List of header files (*.h).
        file_specific_flags: Per-file dict of flags to be passed to clang.
        **kwargs: Attributes relevant for a common rule.
    """
    clang_tool = "@llvm-project//clang:clang"
    # NOTE(review): clang_include is unused below — presumably a leftover;
    # verify before removing.
    clang_include = "@llvm-raw//:clang/lib/Headers"
    llvm_link_tool = "@llvm-project//llvm:llvm-link"
    opt_tool = "@llvm-project//llvm:opt"
    prepare_builtins_tool = ":prepare_builtins"

    # De-duplicate header directories while preserving order (Starlark has no
    # set type, so dict keys are used).
    include_paths = dict([(paths.dirname(h), None) for h in hdrs]).keys()
    includes = " ".join(["-I$(location {})".format(inc) for inc in include_paths])
    # Common clang flags for OpenCL-to-AMDGCN bitcode compilation.
    flags = ("-fcolor-diagnostics -Werror -Wno-error=atomic-alignment -x cl -Xclang " +
             "-cl-std=CL2.0 --target=amdgcn-amd-amdhsa -fvisibility=hidden -fomit-frame-pointer " +
             "-Xclang -finclude-default-header -Xclang -fexperimental-strict-floating-point " +
             "-Xclang -fdenormal-fp-math=dynamic -Xclang -Qn " +
             "-nogpulib -cl-no-stdinc -Xclang -mcode-object-version=none")

    # Accumulates per-source .bc outputs and pass-through .ll files.
    link_inputs = []

    for src in srcs:
        filename = paths.basename(src)
        # partition (not rpartition): basename is everything before the FIRST
        # dot, so e.g. "foo.cl" -> ("foo", "cl").
        (basename, _, ext) = filename.partition(".")

        if (ext == "ll"):
            # LLVM IR inputs go straight to the link step, uncompiled.
            link_inputs.append(src)
            continue

        out = basename + ".bc"
        link_inputs.append(out)
        extra_flags = " ".join(file_specific_flags.get(filename,[]))
        native.genrule(
            name = "compile_" + basename,
            srcs = [src] + hdrs + include_paths,
            outs = [out],
            #TODO(rocm): Ugly hack to access builtin clang includes.
            cmd = "$(location {}) -I$(execpath {}).runfiles/llvm-project/clang/staging/include/ {} {} {} -emit-llvm -c $(location {}) -o $@".format(
                clang_tool, clang_tool, includes, flags, extra_flags, src),
            tools = [clang_tool],
            message = "Compiling {} ...".format(filename),
        )

    link_message = "Linking {}.bc ...".format(name)

    # Step 1: link all per-file bitcode into one module.
    prelink_out = name + ".link0.lib.bc"
    native.genrule(
        name = "prelink_" + name,
        srcs = link_inputs,
        outs = [prelink_out],
        cmd = "$(location {}) $(SRCS) -o $@".format(llvm_link_tool),
        tools = [llvm_link_tool],
        message = link_message,
    )

    # Step 2: internalize symbols, keeping only what is needed.
    internalize_out = name + ".lib.bc"
    native.genrule(
        name = "internalize_" + name,
        srcs = [prelink_out],
        outs = [internalize_out],
        cmd = "$(location {}) -internalize -only-needed $< -o $@".format(llvm_link_tool),
        tools = [llvm_link_tool],
        message = link_message,
    )

    # Step 3: unify AMDGPU metadata and strip symbols/debug info.
    strip_out = name + ".strip.bc"
    native.genrule(
        name = "strip_" + name,
        srcs = [internalize_out],
        outs = [strip_out],
        cmd = "$(location {}) -passes=amdgpu-unify-metadata,strip -o $@ $<".format(opt_tool),
        tools = [opt_tool],
        message = link_message,
    )

    # Step 4: final prepare_builtins post-processing produces "<name>.bc".
    native.genrule(
        name = name,
        srcs = [strip_out],
        outs = [name + ".bc"],
        cmd = "$(location {}) -o $@ $<".format(prepare_builtins_tool),
        tools = [prepare_builtins_tool],
        message = link_message,
    )
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
diff --git a/utils/prepare-builtins/prepare-builtins.cpp b/utils/prepare-builtins/prepare-builtins.cpp
index 7fc9d06dab7d..2a93638c3f8f 100644
--- a/utils/prepare-builtins/prepare-builtins.cpp
+++ b/utils/prepare-builtins/prepare-builtins.cpp
@@ -73,6 +73,13 @@ int main(int argc, char **argv) {
return 1;
}

+ // Strip the OpenCL version metadata. There are a lot of linked
+ // modules in the library build, each spamming the same
+ // version. This may also report a different version than the user
+ // program is using. This should probably be uniqued when linking.
+ if (NamedMDNode *OCLVersion = M->getNamedMetadata("opencl.ocl.version"))
+ M->eraseNamedMetadata(OCLVersion);
+
// Set linkage of every external definition to linkonce_odr.
for (Module::iterator i = M->begin(), e = M->end(); i != e; ++i) {
if (!i->isDeclaration() && i->getLinkage() == GlobalValue::ExternalLinkage) {
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
load("build_defs.bzl", "bitcode_library")

licenses(["notice"])

package(default_visibility = ["//visibility:public"])

exports_files([
"LICENSE.TXT",
])

# Host tool that post-processes linked device-lib bitcode (see
# bitcode_library in build_defs.bzl, which invokes it as the last step).
cc_binary(
    name = "prepare_builtins",
    srcs = glob([
        "utils/prepare-builtins/*.cpp",
        "utils/prepare-builtins/*.h",
    ]),
    copts = [
        # One flag per list element: Bazel treats each copt entry as a unit,
        # and single-flag entries are the documented best practice (the
        # original packed both flags into one string).
        "-fno-rtti",
        "-fno-exceptions",
    ],
    deps = [
        "@llvm-project//llvm:BitReader",
        "@llvm-project//llvm:BitWriter",
        "@llvm-project//llvm:Core",
        "@llvm-project//llvm:IRReader",
        "@llvm-project//llvm:Support",
    ],
    visibility = ["//visibility:private"],
)

# OCML: OpenCL math library bitcode, built from the *.cl sources.
bitcode_library(
    name = "ocml",
    srcs = glob([
        "ocml/src/*.cl"
    ]),
    hdrs = glob([
        "ocml/src/*.h",
        "ocml/inc/*.h",
        "irif/inc/*.h",
        "oclc/inc/*.h",
    ]),
    # Per-file clang flags — presumably mirroring the upstream device-libs
    # build settings for these files; verify against the ROCm CMake build.
    file_specific_flags = {
        "native_logF.cl": ["-fapprox-func"],
        "native_expF.cl": ["-fapprox-func"],
        "sqrtF.cl": ["-cl-fp32-correctly-rounded-divide-sqrt"],
    },
)

# OCKL: OpenCL kernel library bitcode; includes raw *.ll inputs that are
# linked in without compilation.
bitcode_library(
    name = "ockl",
    srcs = glob([
        "ockl/src/*.cl",
        "ockl/src/*.ll",
    ]),
    hdrs = glob([
        "ockl/inc/*.h",
        "irif/inc/*.h",
        "oclc/inc/*.h",
    ]),
    # Per-file clang flag — presumably matching the upstream device-libs
    # build for gaaf.cl; verify against the ROCm CMake build.
    file_specific_flags = {
        "gaaf.cl": ["-munsafe-fp-atomics"],
    },
)
22 changes: 22 additions & 0 deletions third_party/xla/third_party/rocm_device_libs/workspace.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Provides the repository macro to import Rocm-Device-Libs"""

load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
    """Imports the ROCm Device Libs (OCML/OCKL bitcode sources).

    The device libs live inside ROCm's fork of llvm-project under
    amd/device-libs, so the llvm-project archive is downloaded and stripped
    down to that subdirectory.
    """

    # Pin to one exact commit of ROCm's llvm-project fork. Renamed from the
    # original LLVM_COMMIT/LLVM_SHA256, which misleadingly suggested this
    # pinned upstream LLVM rather than the ROCm device libs.
    ROCM_LLVM_COMMIT = "0cf1859d038376421b4cd597e3df90d37cfca06e"
    ROCM_LLVM_SHA256 = "0374d1efa0f049d2d1c24c4d86029b006cb5594cc0a1b6a18c49fb094c29cd29"

    tf_http_archive(
        name = "rocm_device_libs",
        sha256 = ROCM_LLVM_SHA256,
        # Keep only the device-libs subtree of the archive.
        strip_prefix = "llvm-project-{commit}/amd/device-libs".format(commit = ROCM_LLVM_COMMIT),
        urls = tf_mirror_urls("https://github.com/ROCm/llvm-project/archive/{commit}.tar.gz".format(commit = ROCM_LLVM_COMMIT)),
        build_file = "//third_party/rocm_device_libs:rocm_device_libs.BUILD",
        patch_file = [
            # Strips noisy per-module OpenCL version metadata from the bitcode.
            "//third_party/rocm_device_libs:prepare_builtins.patch",
        ],
        link_files = {
            "//third_party/rocm_device_libs:build_defs.bzl": "build_defs.bzl",
        },
    )
2 changes: 2 additions & 0 deletions third_party/xla/workspace2.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ load("//third_party/py:python_configure.bzl", "python_configure")
load("//third_party/py/ml_dtypes:workspace.bzl", ml_dtypes = "repo")
load("//third_party/pybind11_abseil:workspace.bzl", pybind11_abseil = "repo")
load("//third_party/pybind11_bazel:workspace.bzl", pybind11_bazel = "repo")
load("//third_party/rocm_device_libs:workspace.bzl", rocm_device_libs = "repo")
load("//third_party/robin_map:workspace.bzl", robin_map = "repo")
load("//third_party/shardy:workspace.bzl", shardy = "repo")
load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo")
Expand Down Expand Up @@ -72,6 +73,7 @@ def _initialize_third_party():
nvshmem()
pybind11_abseil()
pybind11_bazel()
rocm_device_libs()
robin_map()
shardy()
stablehlo()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,8 @@ absl::Status CreateTritonPipeline(
}

std::string GetLibdevicePath(const HloModuleConfig& hlo_config,
const se::DeviceDescription& device_info) {
std::string libdevice_dir = tsl::RocdlRoot();
auto compute_capability = device_info.rocm_compute_capability();
const std::string libdevice_path =
amdgpu::LibDevicePath(compute_capability.gcn_arch_name(), libdevice_dir);
return libdevice_path;
const se::DeviceDescription& device_info) {
return "__builtin__";
}

} // namespace gpu
Expand Down
27 changes: 27 additions & 0 deletions third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ load(
"if_oss",
"internal_visibility",
)
load("//xla:strict.default.bzl", "py_strict_binary")

package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
Expand Down Expand Up @@ -128,6 +129,31 @@ cc_library(
],
)

# Tool that turns linked device-lib bitcode into an embeddable data file.
py_strict_binary(
    name = "generate_amdgpu_device_lib_data_tool",
    srcs = ["generate_amdgpu_device_lib_data_tool.py"],
)

# Runs the tool over the prebuilt ockl/ocml bitcode libraries (using
# llvm-link) to emit a C++ include file with the embedded bitcode.
genrule(
    name = "generate_amdgpu_device_lib_data",
    srcs = [
        "@rocm_device_libs//:ockl",
        "@rocm_device_libs//:ocml",
    ],
    outs = ["amdgpu_device_lib_data.inc"],
    cmd = "$(location {}) --llvm_link_bin $(location {}) $(SRCS) -o $@".format(
        ":generate_amdgpu_device_lib_data_tool", "@llvm-project//llvm:llvm-link"),
    tools = [":generate_amdgpu_device_lib_data_tool", "@llvm-project//llvm:llvm-link"],
)

# Header-only wrapper exposing the generated .inc to C++ code as
# "amdgpu_device_lib_data.inc".
cc_library(
    name = "amdgpu_device_lib_data",
    hdrs = [
        ":generate_amdgpu_device_lib_data",
    ],
    include_prefix = ".",
)

cc_library(
name = "amdgpu_backend",
srcs = [
Expand All @@ -138,6 +164,7 @@ cc_library(
],
local_defines = if_oss(["HAS_SUPPORT_FOR_LLD_AS_A_LIBRARY=1"]),
deps = [
":amdgpu_device_lib_data",
":llvm_gpu_backend",
":load_ir_module",
"//xla:util",
Expand Down
Loading