@@ -223,6 +223,15 @@ static __global__ void add_f32(const float * x, const float * y, float * dst, co
     dst[i] = x[i] + y[i];
 }
 
+static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = __hadd(x[i], __float2half(y[i]));
+}
+
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -1459,6 +1468,11 @@ static void add_f32_cuda(const float * x, const float * y, float * dst, const in
     add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
 }
 
+static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+}
+
 static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
     const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
     mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
@@ -1941,15 +1955,21 @@ inline void ggml_cuda_op_add(
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
     cudaStream_t & cudaStream_main){
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr);
     GGML_ASSERT(src1_ddf_i != nullptr);
     GGML_ASSERT(dst_ddf_i != nullptr);
 
     const int64_t ne0 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;
 
     // compute
-    add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    } else {
+        GGML_ASSERT(false);
+    }
     CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
@@ -2547,8 +2567,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
 }
 
 void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true, true);
+    // ggml_cuda_add permits f16 dst even though this could in theory cause problems with the pointer arithmetic in ggml_cuda_op.
+    // Due to flatten_rows == true this does in practice not make a difference however.
+    // Better solution would be nice but right now that would require disproportionate changes.
+    GGML_ASSERT(
+        (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
+        src1->type == GGML_TYPE_F32 &&
+        (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16));
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, false, true);
 }
 
 void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2801,7 +2827,7 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
     delete extra;
 }
 
-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
     if (scratch && g_scratch_size == 0) {
         return;
     }
@@ -2810,23 +2836,24 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
     if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src0->op;
         if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
-            ggml_cuda_assign_buffers_impl(tensor->src0, scratch);
+            ggml_cuda_assign_buffers_impl(tensor->src0, scratch, force_inplace);
         }
     }
     if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) {
-        ggml_cuda_assign_buffers_impl(tensor->src1, scratch);
+        ggml_cuda_assign_buffers_impl(tensor->src1, scratch, force_inplace);
     }
 
     tensor->backend = GGML_BACKEND_GPU;
     struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
     memset(extra, 0, sizeof(*extra));
 
     const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
-        tensor->op == GGML_OP_VIEW;
+        tensor->op == GGML_OP_VIEW ||
+        force_inplace;
     const size_t size = ggml_nbytes(tensor);
 
     CUDA_CHECK(cudaSetDevice(g_main_device));
-    if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) {
+    if (inplace && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT)) {
         struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) tensor->src0->extra;
         char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
         size_t offset = 0;
@@ -2865,11 +2892,15 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
 }
 
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, true);
+    ggml_cuda_assign_buffers_impl(tensor, true, false);
 }
 
 void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false);
+    ggml_cuda_assign_buffers_impl(tensor, false, false);
+}
+
+void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, false, true);
 }
 
 void ggml_cuda_set_main_device(int main_device) {
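For orientation, the three public wrappers now map onto ggml_cuda_assign_buffers_impl(tensor, scratch, force_inplace) as (true, false), (false, false) and (false, true). The snippet below is an illustrative sketch only; the tensor names and the surrounding graph are placeholders, not taken from this commit. It shows the intended distinction: scratch placement for throwaway intermediates, a dedicated allocation for long-lived data, and force_inplace for a result that must reuse its src0's VRAM buffer rather than get a new one.

#include "ggml.h"
#include "ggml-cuda.h"

// Illustrative only -- placeholder graph, no error handling.
static void build_example(struct ggml_context * ctx, struct ggml_tensor * cur) {
    // intermediate result: placed in the VRAM scratch buffer
    struct ggml_tensor * tmp = ggml_mul_mat(ctx, cur, cur);
    ggml_cuda_assign_buffers(tmp);                    // impl(tmp, true, false)

    // long-lived tensor: gets its own dedicated VRAM allocation
    struct ggml_tensor * cache = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 4096);
    ggml_cuda_assign_buffers_no_scratch(cache);       // impl(cache, false, false)

    // result that must overwrite its src0's VRAM buffer instead of getting its own
    struct ggml_tensor * scaled = ggml_scale_inplace(ctx, tmp, ggml_new_f32(ctx, 0.125f));
    ggml_cuda_assign_buffers_force_inplace(scaled);   // impl(scaled, false, true)
}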