@@ -49,19 +49,10 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr,
                                             const bool zfill = false) {
   const int BYTES = 16;
   int src_in_bytes = (zfill ? 0 : BYTES);
-#ifdef USE_ROCM
-  // uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
-  // asm volatile(
-  //     "{\n"
-  //     "   .reg .pred p;\n"
-  //     "   setp.ne.b32 p, %0, 0;\n"
-  //     "   @p cp.async [%1], [%2], %3;\n" // AMD ROCm equivalent
-  //     "}\n" ::"r"((int)pred),
-  //     "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
   uint32_t smem = cvta_to_shared(smem_ptr);
+#ifdef USE_ROCM
   __builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
 #else
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile(
       "{\n"
       "   .reg .pred p;\n"
@@ -75,19 +66,10 @@ __device__ inline void cp_async4_pred_zfill(void* smem_ptr,
 __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                       bool pred = true) {
   const int BYTES = 16;
-#ifdef USE_ROCM
-  // uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
-  // asm volatile(
-  //     "{\n"
-  //     "   .reg .pred p;\n"
-  //     "   setp.ne.b32 p, %0, 0;\n"
-  //     "   @p ds_read_b128 %1, %2 offset:0;\n" // AMD ROCm equivalent
-  //     "}\n" ::"r"((int)pred),
-  //     "r"(smem), "l"(glob_ptr));
   uint32_t smem = cvta_to_shared(smem_ptr);
+#ifdef USE_ROCM
   __builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
 #else
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile(
       "{\n"
       "   .reg .pred p;\n"
@@ -101,18 +83,10 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
 // Asynchronous global->shared copy
 __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
   const int BYTES = 16;
-#ifdef USE_ROCM
-  // uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
-  // asm volatile(
-  //     "{\n"
-  //     "   ds_read_b128 %0, %1 offset:0;\n"
-  //     "}\n" ::"r"(smem),
-  //     "l"(glob_ptr));
   uint32_t smem = cvta_to_shared(smem_ptr);
+#ifdef USE_ROCM
   __builtin_amdgcn_global_load_lds(static_cast<const uint32_t*>(glob_ptr), &smem, BYTES, 0, 0);
-
 #else
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile(
       "{\n"
       "   cp.async.cg.shared.global [%0], [%1], %2;\n"
@@ -146,17 +120,15 @@ __device__ inline void cp_async_wait() {
 // Instruction for loading a full 16x16 matrix fragment of operand A from shared
 // memory, directly in tensor core layout.
 __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
-#ifdef USE_ROCM
   uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
-  uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
+  uint32_t smem = cvta_to_shared(smem_ptr);
+#ifdef USE_ROCM
   asm volatile(
       "ds_read_b128 %0, %1 offset:0\n"
       "ds_read_b128 %2, %1 offset:16\n"
       : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3])
       : "v"(smem));
 #else
-  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
                : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
                : "r"(smem));
@@ -165,14 +137,13 @@ __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
 
 __device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
   uint32_t* a = reinterpret_cast<uint32_t*>(&frag_m);
+  uint32_t smem = cvta_to_shared(smem_ptr);
 #ifdef USE_ROCM
-  uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
   asm volatile(
       "ds_read_b64 %0, %2 offset:0\n"
       : "=v"(a[0]), "=v"(a[1])
       : "v"(smem));
 #else
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
                : "=r"(a[0]), "=r"(a[1])
                : "r"(smem));
@@ -183,15 +154,14 @@ __device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
 // memory, directly in tensor core layout.
 __device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
   uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
+  uint32_t smem = cvta_to_shared(smem_ptr);
 #ifdef USE_ROCM
-  uint32_t smem = static_cast<uint32_t>(__builtin_amdgcn_s_getpc());
   asm volatile(
-      "ds_read_b128 %0, %4 offset:0\n"
-      "ds_read_b128 %2, %4 offset:16\n"
+      "ds_read_b128 %0, %1 offset:0\n"
+      "ds_read_b128 %2, %1 offset:16\n"
       : "=v"(a[0]), "=v"(a[1]), "=v"(a[2]), "=v"(a[3])
       : "v"(smem));
 #else
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
   asm volatile(
       "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
       : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
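
Note: the cvta_to_shared() helper that the new context lines call is defined outside the hunks shown here. Below is a minimal sketch of what such a helper could look like, assuming it only maps a generic pointer to a 32-bit shared-memory (LDS) address on each backend; the template form and the ROCm cast are assumptions, not this file's actual definition.

// Hypothetical sketch of cvta_to_shared(); the real definition is not part of
// this diff and may differ.
template <typename T>
__device__ inline uint32_t cvta_to_shared(const T* ptr) {
#ifdef USE_ROCM
  // Assumption: LDS addresses fit in 32 bits, so a truncating cast is enough.
  return (uint32_t)(uint64_t)(ptr);
#else
  // Matches the CUDA-side lines removed above: cvta yields a shared-space offset.
  return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
#endif
}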
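
For orientation, here is a hedged usage sketch (not taken from this commit) of how these wrappers typically fit together inside a kernel. It assumes the file also provides the usual cp_async_fence() commit wrapper alongside the cp_async_wait() visible in the hunk context above, and that cp_async_wait takes a template count as in the upstream Marlin kernels.

// Hedged usage sketch: issue 16-byte global->shared copies, commit them,
// wait for completion, then read the tile. Names not shown in this diff
// (cp_async_fence, the template parameter of cp_async_wait) are assumptions.
__global__ void cp_async_usage_example(const int4* __restrict__ gmem) {
  extern __shared__ int4 smem_tile[];      // sized at launch: blockDim.x * 16 bytes
  int tid = threadIdx.x;
  cp_async4(&smem_tile[tid], &gmem[tid]);  // one 16-byte async copy per thread
  cp_async_fence();                        // commit the copies issued so far
  cp_async_wait<0>();                      // block until all committed copies land
  __syncthreads();                         // make the tile visible block-wide
  int4 v = smem_tile[tid];                 // now safe to read from shared memory
  (void)v;
}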