1
1
#include " upscale.cuh"
2
2
3
// Upscales a 4-D float tensor: every destination element copies the source
// element whose coordinate is the destination coordinate divided by the
// per-dimension scale factor, truncated to an integer.
//
// Launch layout: 1-D grid of 1-D blocks, one thread per destination element.
// Threads whose flat index is >= ne10 * ne11 * ne12 * ne13 return without
// touching memory. No shared memory is used.
//
//   x           source data, addressed through the byte strides nb00..nb03
//   dst         destination data, written at the flat element index
//   nb00..nb03  byte strides of the source along dimensions 0..3
//   ne10..ne13  element counts of the destination along dimensions 0..3
//   sf0..sf3    scale factors (destination extent / source extent) per dimension
static __global__ void upscale_f32(const float * x, float * dst,
        const int nb00, const int nb01, const int nb02, const int nb03,
        const int ne10, const int ne11, const int ne12, const int ne13,
        const float sf0, const float sf1, const float sf2, const float sf3) {
    // Widen before multiplying: blockIdx.x * blockDim.x and the element-count
    // product can both exceed the 32-bit range for large tensors.
    const long long index    = (long long) blockIdx.x * blockDim.x + threadIdx.x;
    const long long ne_total = (long long) ne10 * ne11 * ne12 * ne13;
    if (index >= ne_total) {
        return;
    }

    // Decompose the flat destination index into its four coordinates.
    const int i10 = (int) (index % ne10);
    const int i11 = (int) ((index / ne10) % ne11);
    const int i12 = (int) ((index / ((long long) ne10 * ne11)) % ne12);
    const int i13 = (int) ((index / ((long long) ne10 * ne11 * ne12)) % ne13);

    // Map each destination coordinate back to the source; the float quotient
    // is truncated toward zero by the conversion to int.
    const int i00 = (int) (i10 / sf0);
    const int i01 = (int) (i11 / sf1);
    const int i02 = (int) (i12 / sf2);
    const int i03 = (int) (i13 / sf3);

    // Byte offset into the source, accumulated in 64 bits so that
    // coordinate * stride cannot wrap.
    const long long src_offset =
        (long long) i03 * nb03 +
        (long long) i02 * nb02 +
        (long long) i01 * nb01 +
        (long long) i00 * nb00;

    // Read through const-qualified pointers; the source is never modified.
    dst[index] = *(const float *) ((const char *) x + src_offset);
}
26
24
27
// Host-side launcher for upscale_f32.
//
// Enqueues the kernel on `stream` with a 1-D grid of
// CUDA_UPSCALE_BLOCK_SIZE-thread blocks, sized so there is at least one
// thread per destination element. The launch is asynchronous; this function
// does not synchronize.
//
//   x, dst      source and destination buffers, forwarded to the kernel
//   nb00..nb03  byte strides of the source along dimensions 0..3
//   ne10..ne13  element counts of the destination along dimensions 0..3
//   sf0..sf3    per-dimension scale factors
//   stream      CUDA stream the kernel is enqueued on
static void upscale_f32_cuda(const float * x, float * dst,
        const int nb00, const int nb01, const int nb02, const int nb03,
        const int ne10, const int ne11, const int ne12, const int ne13,
        const float sf0, const float sf1, const float sf2, const float sf3,
        cudaStream_t stream) {
    // Multiply in 64 bits: the product of four int extents can exceed INT_MAX.
    const long long dst_size = (long long) ne10 * ne11 * ne12 * ne13;

    // Nothing to write; a grid with zero blocks is an invalid launch
    // configuration, so skip the launch entirely.
    if (dst_size <= 0) {
        return;
    }

    // Ceiling division so a partially filled last block is still launched.
    const long long num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;

    upscale_f32<<<(unsigned int) num_blocks, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(
        x, dst,
        nb00, nb01, nb02, nb03,
        ne10, ne11, ne12, ne13,
        sf0, sf1, sf2, sf3);
}
34
35
35
36
void ggml_cuda_op_upscale (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -39,10 +40,12 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
39
40
cudaStream_t stream = ctx.stream ();
40
41
41
42
GGML_ASSERT (src0->type == GGML_TYPE_F32);
42
- GGML_ASSERT (dst->type == GGML_TYPE_F32);
43
- GGML_ASSERT (src0->ne [3 ] == 1 && dst->ne [3 ] == 1 ); // just 3D tensors
43
+ GGML_ASSERT ( dst->type == GGML_TYPE_F32);
44
44
45
- const int scale_factor = dst->op_params [0 ];
45
+ const float sf0 = (float )dst->ne [0 ]/src0->ne [0 ];
46
+ const float sf1 = (float )dst->ne [1 ]/src0->ne [1 ];
47
+ const float sf2 = (float )dst->ne [2 ]/src0->ne [2 ];
48
+ const float sf3 = (float )dst->ne [3 ]/src0->ne [3 ];
46
49
47
- upscale_f32_cuda (src0_d, dst_d, src0->ne [0 ], src0->ne [1 ], src0->ne [2 ], src0->ne [3 ], scale_factor , stream);
50
+ upscale_f32_cuda (src0_d, dst_d, src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [ 3 ], dst-> ne [0 ], dst-> ne [ 1 ], dst-> ne [ 2 ], dst-> ne [ 3 ], sf0, sf1, sf2, sf3 , stream);
48
51
}
0 commit comments