Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ggml/src/ggml-sycl/backend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "roll.hpp"
#include "rope.hpp"
#include "set_rows.hpp"
#include "ssm_conv.hpp"
#include "softmax.hpp"
#include "tsembd.hpp"
#include "wkv.hpp"
Expand Down
7 changes: 7 additions & 0 deletions ggml/src/ggml-sycl/ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
#include "ggml-sycl/getrows.hpp"
#include "ggml-sycl/repeat_back.hpp"
#include "ggml-sycl/quantize.hpp"
#include "ggml-sycl/ssm_conv.hpp"
#include "ggml.h"

static bool g_sycl_loaded = false;
Expand Down Expand Up @@ -3921,6 +3922,8 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_GATED_LINEAR_ATTN:
ggml_sycl_op_gated_linear_attn(ctx, dst);
break;
case GGML_OP_SSM_CONV:
ggml_sycl_ssm_conv(ctx, dst);
case GGML_OP_ROLL:
ggml_sycl_roll(ctx, dst);
break;
Expand Down Expand Up @@ -4602,6 +4605,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_RWKV_WKV7:
case GGML_OP_GATED_LINEAR_ATTN:
return true;
case GGML_OP_SSM_CONV:
return op->type == GGML_TYPE_F32 &&
op->src[0]->type == GGML_TYPE_F32 &&
op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_ROLL:
return op->type == GGML_TYPE_F32;
case GGML_OP_ARANGE:
Expand Down
127 changes: 127 additions & 0 deletions ggml/src/ggml-sycl/ssm_conv.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#include "ssm_conv.hpp"
#include "common.hpp"

#include <cstdio>

using namespace sycl;

// Launch the SSM 1-D convolution kernel: one work-item per output element
// (channel, token, sequence). Each item computes a dot product between a
// d_conv-wide sliding window of the source and that channel's weights.
// The launch is asynchronous; no wait is performed here.
static void kernel_ssm_conv(
    queue &q,
    const float *src_data,
    const float *weights,
    float *dst_data,
    int d_conv,
    int d_inner,
    int n_t,
    int n_s,
    int ncs __attribute__((unused)),
    int src_stride_inner,
    int src_stride_seq,
    int dst_stride_token,
    int dst_stride_seq
) {
    const size_t n_channels = static_cast<size_t>(d_inner);
    const size_t n_tokens   = static_cast<size_t>(n_t);
    const size_t n_elems    = n_channels * n_tokens * static_cast<size_t>(n_s);

    // Round the global size up to a whole number of work-groups; excess
    // items exit early inside the kernel.
    constexpr size_t wg_size = 256;
    const size_t n_groups = (n_elems + wg_size - 1) / wg_size;

    q.submit([&](handler &cgh) {
        cgh.parallel_for(
            nd_range<1>(range<1>(n_groups * wg_size), range<1>(wg_size)),
            [=](nd_item<1> item) {
                const size_t gid = item.get_global_id(0);
                if (gid >= n_elems) {
                    return;
                }

                // Decompose the flat id: channel is fastest, then token, then sequence.
                const int ch  = static_cast<int>(gid % n_channels);
                const int tok = static_cast<int>((gid / n_channels) % n_tokens);
                const int seq = static_cast<int>(gid / (n_channels * n_tokens));

                // Window of d_conv consecutive source samples for this channel/sequence;
                // consecutive tokens shift the window by one sample.
                const float *window = src_data
                    + static_cast<size_t>(seq) * static_cast<size_t>(src_stride_seq)
                    + static_cast<size_t>(ch)  * static_cast<size_t>(src_stride_inner)
                    + static_cast<size_t>(tok);

                // Per-channel convolution weights.
                const float *w = weights + static_cast<size_t>(ch) * static_cast<size_t>(d_conv);

                float acc = 0.0f;
                for (int k = 0; k < d_conv; ++k) {
                    acc += window[k] * w[k];
                }

                const size_t out_idx =
                    static_cast<size_t>(seq) * static_cast<size_t>(dst_stride_seq) +
                    static_cast<size_t>(tok) * static_cast<size_t>(dst_stride_token) +
                    static_cast<size_t>(ch);

                dst_data[out_idx] = acc;
            }
        );
    });
}

// SYCL implementation of GGML_OP_SSM_CONV.
//   dst->src[0]: conv state, shape {d_conv - 1 + n_t, d_inner, n_s}, F32
//   dst->src[1]: conv weights, shape {d_conv, d_inner}, F32
//   dst:         output, shape {d_inner, n_t, n_s}, F32
// Validates types/shapes/contiguity, then dispatches the kernel on the
// context's stream. Exceptions from the SYCL runtime are logged and rethrown.
void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    ggml_tensor * src0 = dst->src[0];
    ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    // Dimensions (see shape summary above).
    const int d_conv  = src1->ne[0];  // convolution width
    const int ncs     = src0->ne[0];  // samples per channel in the state
    const int d_inner = src0->ne[1];  // number of channels
    const int n_t     = dst->ne[1];   // tokens per sequence
    const int n_s     = dst->ne[2];   // number of sequences

    // The state must hold exactly one window start per token.
    GGML_ASSERT(src0->ne[0] == d_conv - 1 + n_t);
    GGML_ASSERT(src0->ne[1] == d_inner);
    GGML_ASSERT(src1->ne[1] == d_inner);

    GGML_ASSERT(dst->ne[0] == d_inner);
    GGML_ASSERT(dst->ne[1] == n_t);
    GGML_ASSERT(dst->ne[2] == n_s);

    // The kernel indexes raw float arrays, so both inputs must be contiguous
    // in their innermost dimension, and src0 row-contiguous as well.
    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));

    GGML_ASSERT(src0->nb[1] == src0->ne[0] * static_cast<int>(sizeof(float)));

    // Element strides derived from the (contiguous) layouts.
    const int src_stride_inner = ncs;
    const int src_stride_seq   = ncs * d_inner;
    const int dst_stride_token = d_inner;
    const int dst_stride_seq   = d_inner * n_t;

    try {
        queue *stream = ctx.stream();

        const float *state_ptr  = static_cast<const float *>(src0->data);
        const float *weight_ptr = static_cast<const float *>(src1->data);
        float       *out_ptr    = static_cast<float *>(dst->data);

        GGML_ASSERT(state_ptr && weight_ptr && out_ptr);

        kernel_ssm_conv(
            *stream,
            state_ptr,
            weight_ptr,
            out_ptr,
            d_conv,
            d_inner,
            n_t,
            n_s,
            ncs,
            src_stride_inner,
            src_stride_seq,
            dst_stride_token,
            dst_stride_seq
        );

    } catch (const std::exception &e) {
        std::fprintf(stderr, "[SYCL-SSM_CONV] ERROR: %s\n", e.what());
        throw;
    }
}
5 changes: 5 additions & 0 deletions ggml/src/ggml-sycl/ssm_conv.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#pragma once

#include "common.hpp"

// Compute GGML_OP_SSM_CONV (state-space model 1-D convolution) for `dst` on
// the SYCL backend; inputs are read from dst->src[0] (state) and dst->src[1]
// (per-channel weights), all F32.
void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst);