#include <sycl/sycl.hpp>
#include "wkv6.hpp"

constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
5+
6+ // Helper function for the main kernel
// RWKV6 time-mixing (wkv) kernel, ported from the CUDA implementation.
// One work-group processes one (batch, head) pair; each work-item owns one
// channel (`tid`) of that head. The per-channel recurrent state is kept in
// private (register) storage; the current token's k/r/td values and the
// per-head tf bonus are staged in local memory so every work-item can read
// the whole head's values.
//
//   B, T, C, H  - batch size, total tokens, channels, head count
//   k, v, r     - key / value / receptance activations, indexed as [T, C]
//   tf          - per-head "time first" bonus, indexed as [H, head_size]
//   td          - per-token time decay, indexed as [T, C]
//   s           - initial recurrent state, [B, H, head_size, head_size]
//   dst         - per-token outputs [T, C]; final state written at offset T*C
//   shared_mem  - local scratch of 4 * head_size floats (k, r, tf, td tiles)
static void rwkv_wkv_f32_kernel(
    const int B, const int T, const int C, const int H,
    const float* k, const float* v, const float* r,
    const float* tf, const float* td, const float* s,
    float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {

    const int tid = item_ct1.get_local_id(2); // channel within the head
    const int bid = item_ct1.get_group(2);    // flattened (batch, head) index

    const int head_size = WKV_BLOCK_SIZE;     // work-group size == head size
    const int batch_i = bid / H;
    const int head_i = bid % H;
    const int state_size = C * head_size;     // floats of state per batch element
    const int n_seq_tokens = T / B;           // tokens per sequence (assumes T % B == 0)

    // Set up shared memory pointers: four head_size-wide tiles.
    float* _k = shared_mem;
    float* _r = _k + head_size;
    float* _tf = _r + head_size;
    float* _td = _tf + head_size;

    // Local state array: this channel's column of the head's state matrix.
    float state[WKV_BLOCK_SIZE];

    // Load initial state
    #pragma unroll
    for (int i = 0; i < head_size; i++) {
        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
    }

    // Sync threads before shared memory operations
    item_ct1.barrier(sycl::access::fence_space::local_space);

    // Load time-mixing parameters: tf is constant across the sequence, so it
    // is staged into local memory once, outside the token loop.
    _tf[tid] = tf[head_i * head_size + tid];
    item_ct1.barrier(sycl::access::fence_space::local_space);

    // Main sequence processing loop: t walks this channel's element of each
    // token's C-wide row, so stepping by C advances exactly one token.
    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
         t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
         t += C) {

        // Ensure the previous iteration's reads of the tiles are complete
        // before they are overwritten below.
        item_ct1.barrier(sycl::access::fence_space::local_space);

        // Load current timestep data to shared memory
        _k[tid] = k[t];
        _r[tid] = r[t];
        _td[tid] = td[t];

        item_ct1.barrier(sycl::access::fence_space::local_space);

        const float _v = v[t];
        float y = 0;

        // Process in chunks of 4 for better vectorization.
        // NOTE(review): assumes head_size is a multiple of 4 (true for the
        // fixed WKV_BLOCK_SIZE of 64) — no remainder handling here.
        #pragma unroll
        for (int j = 0; j < head_size; j += 4) {
            // Load data in vec4 chunks
            sycl::float4 k4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
            sycl::float4 r4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
            sycl::float4 tf4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
            sycl::float4 td4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
            sycl::float4 s4(state[j], state[j+1], state[j+2], state[j+3]);

            // Compute key-value product
            sycl::float4 kv4 = k4 * _v;

            // Accumulate weighted sum: r . (tf * kv + state)
            y += sycl::dot(r4, tf4 * kv4 + s4);

            // Update state: state = state * td + kv
            s4 = s4 * td4 + kv4;

            // Store updated state
            state[j] = s4.x();
            state[j+1] = s4.y();
            state[j+2] = s4.z();
            state[j+3] = s4.w();
        }

        dst[t] = y;
    }

    // Save final state after the per-token outputs (offset T * C).
    #pragma unroll
    for (int i = 0; i < head_size; i++) {
        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
    }
}
96+
97+ void ggml_sycl_op_rwkv_wkv6 (ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
98+ const ggml_tensor* src1, ggml_tensor* dst) {
99+
100+ const float * k_d = (const float *)dst->src [0 ]->data ;
101+ const float * v_d = (const float *)dst->src [1 ]->data ;
102+ const float * r_d = (const float *)dst->src [2 ]->data ;
103+ const float * tf_d = (const float *)dst->src [3 ]->data ;
104+ const float * td_d = (const float *)dst->src [4 ]->data ;
105+ const float * s_d = (const float *)dst->src [5 ]->data ;
106+ float * dst_d = (float *)dst->data ;
107+
108+ const int64_t B = dst->src [5 ]->ne [1 ];
109+ const int64_t T = dst->src [0 ]->ne [3 ];
110+ const int64_t C = dst->ne [0 ];
111+ const int64_t H = dst->src [0 ]->ne [2 ];
112+
113+ GGML_ASSERT (dst->src [5 ]->type == GGML_TYPE_F32);
114+ GGML_ASSERT (C % H == 0 );
115+ GGML_ASSERT (C / H == WKV_BLOCK_SIZE);
116+
117+ dpct::queue_ptr stream = ctx.stream ();
118+
119+ // Calculate execution configuration
120+ const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof (float ); // For k, r, tf, td
121+ sycl::range<3 > block_dims (1 , 1 , C / H);
122+ sycl::range<3 > grid_dims (1 , 1 , B * H);
123+
124+ // Submit kernel
125+ stream->submit ([&](sycl::handler& cgh) {
126+ sycl::local_accessor<float , 1 > shared_mem_acc (shared_mem_size, cgh);
127+
128+ cgh.parallel_for (
129+ sycl::nd_range<3 >(grid_dims * block_dims, block_dims),
130+ [=](sycl::nd_item<3 > item_ct1) {
131+ rwkv_wkv_f32_kernel (
132+ B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
133+ item_ct1, shared_mem_acc.get_pointer ()
134+ );
135+ });
136+ });
137+ }