@@ -586,17 +586,42 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v,
586
586
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA))
587
587
}
588
588
589
+ static __device__ __forceinline__ void ggml_cuda_mad (half2 & acc, const half2 v, const half2 u) {
590
+ #ifdef FAST_FP16_AVAILABLE
591
+ acc += v*u;
592
+ #else
593
+ const float2 tmpv = __half22float2 (v);
594
+ const float2 tmpu = __half22float2 (u);
595
+ float2 tmpacc = __half22float2 (acc);
596
+ tmpacc.x += tmpv.x * tmpu.x ;
597
+ tmpacc.y += tmpv.y * tmpu.y ;
598
+ acc = make_half2 (tmpacc.x , tmpacc.y );
599
+ #endif // FAST_FP16_AVAILABLE
600
+ }
601
+
589
602
// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
590
- template <int nbytes>
603
+ template <int nbytes, int alignment = 0 >
591
604
static __device__ __forceinline__ void ggml_cuda_memcpy_1 (void * __restrict__ dst, const void * __restrict__ src) {
592
- if constexpr (nbytes == 4 ) {
593
- *(int *) dst = *(const int *) src;
594
- } else if constexpr (nbytes == 8 ) {
595
- *(int2 *) dst = *(const int2 *) src;
596
- } else if constexpr (nbytes == 16 ) {
597
- *(int4 *) dst = *(const int4 *) src;
598
- } else {
599
- static_assert (nbytes == 0 && nbytes == -1 , " bad nbytes" );
605
+ if constexpr (alignment != 0 ) {
606
+ static_assert (nbytes % alignment == 0 , " bad alignment" );
607
+ }
608
+ constexpr int nb_per_cpy = alignment == 0 ? nbytes : alignment;
609
+
610
+ #pragma unroll
611
+ for (int i = 0 ; i < nbytes/nb_per_cpy; ++i) {
612
+ if constexpr (nb_per_cpy == 1 ) {
613
+ ((char *) dst)[i] = ((const char *) src)[i];
614
+ } else if constexpr (nb_per_cpy == 2 ) {
615
+ ((short *) dst)[i] = ((const short *) src)[i];
616
+ } else if constexpr (nb_per_cpy == 4 ) {
617
+ ((int *) dst)[i] = ((const int *) src)[i];
618
+ } else if constexpr (nb_per_cpy == 8 ) {
619
+ ((int2 *) dst)[i] = ((const int2 *) src)[i];
620
+ } else if constexpr (nb_per_cpy == 16 ) {
621
+ ((int4 *) dst)[i] = ((const int4 *) src)[i];
622
+ } else {
623
+ static_assert (nbytes == 0 && nbytes == -1 , " bad nbytes" );
624
+ }
600
625
}
601
626
}
602
627
0 commit comments