Commit d2c7ce4

Author: Peter Y. Yeh
Commit message: consolidate code with cvta_to_shared()
Parent: a80730b

File tree

  • torchao/csrc/cuda/sparse_marlin
      mem.h

1 file changed: +9 −39 lines


torchao/csrc/cuda/sparse_marlin/mem.h

Lines changed: 9 additions & 39 deletions
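For readers unfamiliar with the helper this commit consolidates on: cvta_to_shared() is defined elsewhere in the sparse_marlin sources and is not part of this diff. A minimal sketch of what such a helper could look like follows; the CUDA branch mirrors the per-call-site code removed here, while the ROCm branch is only an illustrative assumption, not the actual implementation.

#include <cstdint>

// Hypothetical sketch of the shared-address conversion helper this commit
// consolidates on; the real cvta_to_shared() lives outside this diff.
__device__ inline uint32_t cvta_to_shared(const void* smem_ptr) {
#ifdef USE_ROCM
  // Assumption for illustration: take the LDS address directly from the
  // pointer bits (the actual ROCm definition is not shown in this diff).
  return static_cast<uint32_t>(reinterpret_cast<uintptr_t>(smem_ptr));
#else
  // Mirrors the code removed at each call site by this commit.
  return static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
#endif
}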
@@ -49,19 +49,10 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr,
                                             const bool zfill = false) {
   const int BYTES = 16;
   int src_in_bytes = (zfill ? 0 : BYTES);
-#ifdef USE_ROCM
-  //uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
-  //asm volatile(
-  //    "{\n"
-  //    "   .reg .pred p;\n"
-  //    "   setp.ne.b32 p, %0, 0;\n"
-  //    "   @p cp.async [%1], [%2], %3;\n" // AMD ROCm equivalent
-  //    "}\n" ::"r"((int)pred),
-  //    "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
   uint32_t smem = cvta_to_shared(smem_ptr);
+#ifdef USE_ROCM
   __builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
 #else
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile(
       "{\n"
       "   .reg .pred p;\n"
@@ -75,19 +66,10 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr,
 __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                       bool pred = true) {
   const int BYTES = 16;
-#ifdef USE_ROCM
-  //uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
-  //asm volatile(
-  //    "{\n"
-  //    "   .reg .pred p;\n"
-  //    "   setp.ne.b32 p, %0, 0;\n"
-  //    "   @p ds_read_b128 %1, %2 offset:0;\n" // AMD ROCm equivalent
-  //    "}\n" ::"r"((int)pred),
-  //    "r"(smem), "l"(glob_ptr));
   uint32_t smem = cvta_to_shared(smem_ptr);
+#ifdef USE_ROCM
   __builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
 #else
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile(
       "{\n"
       "   .reg .pred p;\n"
@@ -101,18 +83,10 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
 // Asynchronous global->shared copy
 __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
   const int BYTES = 16;
-#ifdef USE_ROCM
-  //uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
-  //asm volatile(
-  //    "{\n"
-  //    "   ds_read_b128 %0, %1 offset:0;\n"
-  //    "}\n" ::"r"(smem),
-  //    "l"(glob_ptr));
   uint32_t smem = cvta_to_shared(smem_ptr);
+#ifdef USE_ROCM
   __builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
-
 #else
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile(
       "{\n"
       "   cp.async.cg.shared.global [%0], [%1], %2;\n"
@@ -146,17 +120,15 @@ __device__ inline void cp_async_wait() {
 // Instruction for loading a full 16x16 matrix fragment of operand A from shared
 // memory, directly in tensor core layout.
 __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
-#ifdef USE_ROCM
   uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
-  uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
+  uint32_t smem = cvta_to_shared(smem_ptr);
+#ifdef USE_ROCM
   asm volatile(
       "ds_read_b128 %0, %1 offset:0\n"
       "ds_read_b128 %2, %1 offset:16\n"
       : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3])
       : "v"(smem));
 #else
-  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
               : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
               : "r"(smem));
@@ -165,14 +137,13 @@ __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
 
 __device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
   uint32_t* a = reinterpret_cast<uint32_t*>(&frag_m);
+  uint32_t smem = cvta_to_shared(smem_ptr);
 #ifdef USE_ROCM
-  uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
   asm volatile(
       "ds_read_b64 %0, %2 offset:0\n"
       : "=v"(a[0]), "=v"(a[1])
       : "v"(smem));
 #else
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
               : "=r"(a[0]), "=r"(a[1])
               : "r"(smem));
@@ -183,15 +154,14 @@ __device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
 // memory, directly in tensor core layout.
 __device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
   uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
+  uint32_t smem = cvta_to_shared(smem_ptr);
 #ifdef USE_ROCM
-  uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
   asm volatile(
-      "ds_read_b128 %0, %4 offset:0\n"
-      "ds_read_b128 %2, %4 offset:16\n"
+      "ds_read_b128 %0, %1 offset:0\n"
+      "ds_read_b128 %2, %1 offset:16\n"
       : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3])
       : "v"(smem));
 #else
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile(
       "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
       : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
