@@ -6893,6 +6893,8 @@ static void ggml_cuda_op_mul_mat(
     int64_t row_low[GGML_CUDA_MAX_DEVICES];
     int64_t row_high[GGML_CUDA_MAX_DEVICES];
 
+    int used_devices = 0;
+
     for (int64_t id = 0; id < g_device_count; ++id) {
         // by default, use all rows
         row_low[id] = 0;
@@ -6920,6 +6922,8 @@ static void ggml_cuda_op_mul_mat(
             continue;
         }
 
+        used_devices++;
+
         const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
         const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
 
@@ -6958,12 +6962,12 @@ static void ggml_cuda_op_mul_mat(
 
     // if multiple devices are used they need to wait for the main device
     // here an event is recorded that signals that the main device has finished calculating the input data
-    if (split && g_device_count > 1) {
+    if (split && used_devices > 1) {
         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0]));
     }
 
-    const int64_t src1_col_stride = split && g_device_count > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
+    const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
     for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
         const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0;
         const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
@@ -7079,6 +7083,9 @@ static void ggml_cuda_op_mul_mat(
     }
 
     for (int64_t id = 0; id < g_device_count; ++id) {
+        if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
+            continue;
+        }
         CUDA_CHECK(ggml_cuda_set_device(id));
 
         // free buffers again when done
@@ -7103,6 +7110,9 @@ static void ggml_cuda_op_mul_mat(
 
         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         for (int64_t id = 0; id < g_device_count; ++id) {
+            if (row_low[id] == row_high[id]) {
+                continue;
+            }
             for (int64_t is = 0; is < is_max; ++is) {
                 CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0));
             }
@@ -7400,7 +7410,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
 
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool all_on_device =
-        (src0->backend == GGML_BACKEND_GPU) &&
+        (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
         (src1->backend == GGML_BACKEND_GPU) &&
         ( dst->backend == GGML_BACKEND_GPU);
 
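
To illustrate the pattern this change introduces, here is a minimal, self-contained C sketch (hypothetical names such as count_used_devices; this is not code from the repository) of counting only the devices that were assigned a non-empty row range, so cross-device synchronization is skipped when a single device ends up doing all the work:

#include <stdio.h>

#define MAX_DEVICES 8

// Hypothetical helper: a device participates only if it was assigned a
// non-empty row range, i.e. row_low[id] != row_high[id].
static int count_used_devices(const long row_low[], const long row_high[], int device_count) {
    int used_devices = 0;
    for (int id = 0; id < device_count; ++id) {
        if (row_low[id] == row_high[id]) {
            continue; // no rows assigned -> device is skipped entirely
        }
        used_devices++;
    }
    return used_devices;
}

int main(void) {
    // Example split: device 0 gets rows [0, 4096), device 1 gets nothing.
    long row_low[MAX_DEVICES]  = {0,    4096};
    long row_high[MAX_DEVICES] = {4096, 4096};

    int used = count_used_devices(row_low, row_high, 2);

    // Cross-device events and stream waits are only needed when more than
    // one device actually does work, mirroring the `used_devices > 1` checks.
    printf("used devices: %d, need sync: %s\n", used, used > 1 ? "yes" : "no");
    return 0;
}

Gating the event record and the stream waits on used_devices rather than g_device_count avoids waiting on events from devices whose row range is empty and which therefore never launch any work.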