@@ -42,12 +42,15 @@ namespace __ESIMD_DNS {
4242// Provides access to sycl accessor class' private members.
4343class AccessorPrivateProxy {
4444public:
45- #ifdef __SYCL_DEVICE_ONLY__
4645 template <typename AccessorTy>
4746 static auto getNativeImageObj (const AccessorTy &Acc) {
47+ #ifdef __SYCL_DEVICE_ONLY__
4848 return Acc.getNativeImageObj ();
49- }
5049#else // __SYCL_DEVICE_ONLY__
50+ return Acc;
51+ #endif // __SYCL_DEVICE_ONLY__
52+ }
53+ #ifndef __SYCL_DEVICE_ONLY__
5154 static void *getPtr (const sycl::detail::AccessorBaseHost &Acc) {
5255 return Acc.getPtr ();
5356 }
@@ -421,18 +424,32 @@ __esimd_scatter_scaled(__ESIMD_DNS::simd_mask_storage_t<N> pred,
421424 static_assert (TySizeLog2 <= 2 );
422425 static_assert (std::is_integral<Ty>::value || TySizeLog2 == 2 );
423426
427+ // determine the original element's type size (as __esimd_scatter_scaled
428+ // requires vals to be a vector of 4-byte integers)
429+ constexpr size_t OrigSize = __ESIMD_DNS::ElemsPerAddrDecoding (TySizeLog2);
430+ using RestoredTy = __ESIMD_DNS::uint_type_t <OrigSize>;
431+
424432 sycl::detail::ESIMDDeviceInterface *I =
425433 sycl::detail::getESIMDDeviceInterface ();
426434
435+ __ESIMD_DNS::vector_type_t <RestoredTy, N> TypeAdjustedVals;
436+ if constexpr (OrigSize == 4 ) {
437+ TypeAdjustedVals = __ESIMD_DNS::bitcast<RestoredTy, Ty, N>(vals);
438+ } else {
439+ static_assert (OrigSize == 1 || OrigSize == 2 );
440+ TypeAdjustedVals = __ESIMD_DNS::convert_vector<RestoredTy, Ty, N>(vals);
441+ }
442+
427443 if (surf_ind == __ESIMD_NS::detail::SLM_BTI) {
428444 // Scattered-store for Shared Local Memory
429445 // __ESIMD_NS::detail::SLM_BTI is special binding table index for SLM
430446 assert (global_offset == 0 );
431447 char *SlmBase = I->__cm_emu_get_slm_ptr ();
432448 for (int i = 0 ; i < N; ++i) {
433449 if (pred[i]) {
434- Ty *addr = reinterpret_cast <Ty *>(elem_offsets[i] + SlmBase);
435- *addr = vals[i];
450+ RestoredTy *addr =
451+ reinterpret_cast <RestoredTy *>(elem_offsets[i] + SlmBase);
452+ *addr = TypeAdjustedVals[i];
436453 }
437454 }
438455 } else {
@@ -449,8 +466,9 @@ __esimd_scatter_scaled(__ESIMD_DNS::simd_mask_storage_t<N> pred,
449466
450467 for (int idx = 0 ; idx < N; idx++) {
451468 if (pred[idx]) {
452- Ty *addr = reinterpret_cast <Ty *>(elem_offsets[idx] + writeBase);
453- *addr = vals[idx];
469+ RestoredTy *addr =
470+ reinterpret_cast <RestoredTy *>(elem_offsets[idx] + writeBase);
471+ *addr = TypeAdjustedVals[idx];
454472 }
455473 }
456474
@@ -629,7 +647,12 @@ __esimd_gather_masked_scaled2(SurfIndAliasTy surf_ind, uint32_t global_offset,
629647{
630648 static_assert (Scale == 0 );
631649
632- __ESIMD_DNS::vector_type_t <Ty, N> retv = 0 ;
650+ // determine the original element's type size (as __esimd_gather_masked_scaled2
651+ // reads elements of that size from memory and converts them to a vector of Ty)
652+ constexpr size_t OrigSize = __ESIMD_DNS::ElemsPerAddrDecoding (TySizeLog2);
653+ using RestoredTy = __ESIMD_DNS::uint_type_t <OrigSize>;
654+
655+ __ESIMD_DNS::vector_type_t <RestoredTy, N> retv = 0 ;
633656 sycl::detail::ESIMDDeviceInterface *I =
634657 sycl::detail::getESIMDDeviceInterface ();
635658
@@ -639,7 +662,8 @@ __esimd_gather_masked_scaled2(SurfIndAliasTy surf_ind, uint32_t global_offset,
639662 char *SlmBase = I->__cm_emu_get_slm_ptr ();
640663 for (int idx = 0 ; idx < N; ++idx) {
641664 if (pred[idx]) {
642- Ty *addr = reinterpret_cast <Ty *>(offsets[idx] + SlmBase);
665+ RestoredTy *addr =
666+ reinterpret_cast <RestoredTy *>(offsets[idx] + SlmBase);
643667 retv[idx] = *addr;
644668 }
645669 }
@@ -655,15 +679,21 @@ __esimd_gather_masked_scaled2(SurfIndAliasTy surf_ind, uint32_t global_offset,
655679 std::unique_lock<std::mutex> lock (*mutexLock);
656680 for (int idx = 0 ; idx < N; idx++) {
657681 if (pred[idx]) {
658- Ty *addr = reinterpret_cast <Ty *>(offsets[idx] + readBase);
682+ RestoredTy *addr =
683+ reinterpret_cast <RestoredTy *>(offsets[idx] + readBase);
659684 retv[idx] = *addr;
660685 }
661686 }
662687
663688 // TODO : Optimize
664689 I->cm_fence_ptr ();
665690 }
666- return retv;
691+
692+ if constexpr (OrigSize == 4 ) {
693+ return __ESIMD_DNS::bitcast<Ty, RestoredTy, N>(retv);
694+ } else {
695+ return __ESIMD_DNS::convert_vector<Ty, RestoredTy, N>(retv);
696+ }
667697}
668698#endif // __SYCL_DEVICE_ONLY__
669699
0 commit comments