@@ -974,9 +974,16 @@ kernel void kernel_mul(
974974 device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1 [0 ];
975975 device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs ;
976976
977- for (int i0 = tpitg.x ; i0 < args.ne0 ; i0 += ntg.x ) {
978- const int i10 = i0%args.ne10 ;
979- *((device float *)(dst_ptr + i0*args.nb0 )) = *((device float *)(src0_ptr + i0*args.nb00 )) * *((device float *)(src1_ptr + i10*args.nb10 ));
977+ if (args.ne10 == 1 ) {
978+ const float x = *((device float *)(src1_ptr));
979+ for (int i0 = tpitg.x ; i0 < args.ne0 ; i0 += ntg.x ) {
980+ *((device float *)(dst_ptr + i0*args.nb0 )) = *((device float *)(src0_ptr + i0*args.nb00 )) * x;
981+ }
982+ } else {
983+ for (int i0 = tpitg.x ; i0 < args.ne0 ; i0 += ntg.x ) {
984+ const int i10 = i0%args.ne10 ;
985+ *((device float *)(dst_ptr + i0*args.nb0 )) = *((device float *)(src0_ptr + i0*args.nb00 )) * *((device float *)(src1_ptr + i10*args.nb10 ));
986+ }
980987 }
981988}
982989
@@ -1000,9 +1007,16 @@ kernel void kernel_div(
10001007 device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1 [0 ];
10011008 device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs ;
10021009
1003- for (int i0 = tpitg.x ; i0 < args.ne0 ; i0 += ntg.x ) {
1004- const int i10 = i0%args.ne10 ;
1005- *((device float *)(dst_ptr + i0*args.nb0 )) = *((device float *)(src0_ptr + i0*args.nb00 )) / *((device float *)(src1_ptr + i10*args.nb10 ));
1010+ if (args.ne10 == 1 ) {
1011+ const float x = 1 .0f / *((device float *)(src1_ptr));
1012+ for (int i0 = tpitg.x ; i0 < args.ne0 ; i0 += ntg.x ) {
1013+ *((device float *)(dst_ptr + i0*args.nb0 )) = *((device float *)(src0_ptr + i0*args.nb00 )) * x;
1014+ }
1015+ } else {
1016+ for (int i0 = tpitg.x ; i0 < args.ne0 ; i0 += ntg.x ) {
1017+ const int i10 = i0%args.ne10 ;
1018+ *((device float *)(dst_ptr + i0*args.nb0 )) = *((device float *)(src0_ptr + i0*args.nb00 )) / *((device float *)(src1_ptr + i10*args.nb10 ));
1019+ }
10061020 }
10071021}
10081022
0 commit comments