52 changes: 23 additions & 29 deletions gpt_oss/metal/benchmark/end-to-end.cc
@@ -1,6 +1,7 @@
 #include <gpt-oss.h>
 #include <internal/model.h>
 
+#include <array>
 #include <cstdint>
 #include <cstddef>
 #include <format>
@@ -12,7 +13,7 @@
 #include <benchmark/benchmark.h>
 
 
-constexpr std::uint32_t num_generated_tokens = 100;
+constexpr std::uint32_t kNumGeneratedTokens = 100;
 
 
 static void end2end(benchmark::State& state, const char* env_var_name) {
@@ -30,14 +31,6 @@ static void end2end(benchmark::State& state, const char* env_var_name) {
     }
     std::unique_ptr<std::remove_pointer_t<gptoss_model_t>, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);
 
-    gptoss_tokenizer_t tokenizer_ptr = nullptr;
-    status = gptoss_model_get_tokenizer(model.get(), &tokenizer_ptr);
-    if (status != gptoss_status_success) {
-        state.SkipWithError("failed to retrieve Tokenizer");
-        return;
-    }
-    std::unique_ptr<std::remove_pointer_t<gptoss_tokenizer_t>, decltype(&gptoss_tokenizer_release)> tokenizer(tokenizer_ptr, gptoss_tokenizer_release);
-
     gptoss_context_t context_ptr = nullptr;
     status = gptoss_context_create(model.get(), /*context_length=*/0, &context_ptr);
     if (status != gptoss_status_success) {
@@ -60,50 +53,51 @@ static void end2end(benchmark::State& state, const char* env_var_name) {
         state.SkipWithError("failed to prefill Context object");
         return;
     }
 
+    const std::size_t num_kvcache_tokens = context->num_kv_tokens;
 
     std::uint64_t rng_seed = 0;
     for (std::uint32_t i = 0; i < 3; i++) {
+        const std::uint64_t current_rng_seed = rng_seed++;
         context->num_kv_tokens = num_prompt_tokens;
         context->num_tokens = num_prompt_tokens;
 
-        for (std::uint32_t n = 0; n < num_generated_tokens; n++) {
-            std::uint32_t predicted_token = std::numeric_limits<std::uint32_t>::max();
-            status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/rng_seed++, &predicted_token);
+        std::array<std::uint32_t, kNumGeneratedTokens> tokens;
+        std::size_t num_generated_tokens = 0;
+        do {
+            std::size_t num_current_generated_tokens = 0;
+            status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,
+                /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);
             if (status != gptoss_status_success) {
                 state.SkipWithError("failed to sample from the Context object");
                 return;
             }
-            status = gptoss_context_append_tokens(context.get(), 1, &predicted_token);
-            if (status != gptoss_status_success) {
-                state.SkipWithError(std::format("failed to append token {} to the Context object", predicted_token));
-                return;
-            }
-        }
+            num_generated_tokens += num_current_generated_tokens;
+        } while (num_generated_tokens < kNumGeneratedTokens);
     }
 
     for (auto _ : state) {
+        const std::uint64_t current_rng_seed = rng_seed++;
         context->num_kv_tokens = num_prompt_tokens;
         context->num_tokens = num_prompt_tokens;
 
-        for (std::uint32_t n = 0; n < num_generated_tokens; n++) {
-            std::uint32_t predicted_token = std::numeric_limits<std::uint32_t>::max();
-            status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/rng_seed++, &predicted_token);
+        std::array<std::uint32_t, kNumGeneratedTokens> tokens;
+        std::size_t num_generated_tokens = 0;
+        do {
+            std::size_t num_current_generated_tokens = 0;
+            status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,
+                /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);
             if (status != gptoss_status_success) {
                 state.SkipWithError("failed to sample from the Context object");
                 return;
            }
-            status = gptoss_context_append_tokens(context.get(), 1, &predicted_token);
-            if (status != gptoss_status_success) {
-                state.SkipWithError(std::format("failed to append token {} to the Context object", predicted_token));
-                return;
-            }
-        }
+            num_generated_tokens += num_current_generated_tokens;
+        } while (num_generated_tokens < kNumGeneratedTokens);
     }
 
     state.counters["generations"] =
         benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);
     state.counters["tokens"] =
-        benchmark::Counter(state.iterations() * num_generated_tokens, benchmark::Counter::kIsRate);
+        benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);
 }
 
 BENCHMARK_CAPTURE(end2end, gpt_oss_20b, "GPT_OSS_20B_PATH")
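The new loop shape follows the API change: `gptoss_context_sample` now takes a `max_tokens` budget plus an output array, and the per-token `gptoss_context_append_tokens` call is gone, which suggests sampled tokens are now appended to the context inside the sampler itself. A minimal sketch of the drain pattern in isolation, assuming only the new prototype from `functions.h` below (the helper name and error handling are illustrative, not part of this PR):

```cpp
#include <gpt-oss.h>

#include <cstddef>
#include <cstdint>

// Hypothetical helper: drain the batched sampler until `budget` tokens
// have been produced; each call may return fewer tokens than requested.
static gptoss_status generate_tokens(gptoss_context_t context, std::uint64_t seed,
                                     std::uint32_t* out, std::size_t budget) {
    std::size_t generated = 0;
    while (generated < budget) {
        std::size_t batch = 0;
        const gptoss_status status = gptoss_context_sample(
            context, /*temperature=*/1.0f, /*seed=*/seed,
            /*max_tokens=*/budget - generated, out + generated, &batch);
        if (status != gptoss_status_success) {
            return status;  // propagate; the benchmark reports via SkipWithError
        }
        generated += batch;
    }
    return gptoss_status_success;
}
```

This also explains the restructured warm-up: because sampling mutates the context, each of the three warm-up generations and each measured iteration first rewinds `num_tokens`/`num_kv_tokens` to the prompt length before draining a fresh batch.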
4 changes: 4 additions & 0 deletions gpt_oss/metal/benchmark/f32-bf16w-rmsnorm.cc
@@ -26,6 +26,8 @@ static void f32_bf16w_rnsnorm(benchmark::State& state) {
     Buffer input_buffer{device, num_tokens * num_channels * sizeof(float)};
     Buffer weight_buffer{device, num_channels * sizeof(gptoss_bfloat16)};
     Buffer output_buffer{device, num_tokens * num_channels * sizeof(float)};
+    Buffer control_buffer{device, sizeof(gptoss_control)};
+    std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control));
 
     {
         CommandBuffer command_buffer{command_queue};
@@ -69,6 +71,8 @@ static void f32_bf16w_rnsnorm(benchmark::State& state) {
             /*weight_offset=*/0,
             output_buffer.handle(),
             /*output_offset=*/0,
+            control_buffer.handle(),
+            /*control_offset=*/0,
             num_tokens,
             num_channels,
             kEpsilon),
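The rmsnorm benchmark now allocates and zeroes a control buffer and threads it through the kernel-launch arguments. The actual `gptoss_control` definition lives in an internal header this diff does not show; judging from the `control->abort` check in `accumulate.metal` below, it carries at least an abort flag, roughly:

```cpp
#include <stdint.h>

// Assumed shape of the shared control block (illustrative only; the real
// struct is defined in an internal header and may carry more fields).
struct gptoss_control {
    uint32_t abort;  // non-zero: kernels exit early without doing work
};
```

Zeroing the buffer with `std::memset` before any launch matters here: a stale non-zero flag would silently turn the benchmarked kernel into a no-op.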
4 changes: 3 additions & 1 deletion gpt_oss/metal/include/gpt-oss/functions.h
@@ -290,7 +290,9 @@ enum gptoss_status GPTOSS_ABI gptoss_context_sample(
     gptoss_context_t context,
     float temperature,
     uint64_t seed,
-    uint32_t* token_out);
+    size_t max_tokens,
+    uint32_t* tokens_out,
+    size_t* num_tokens_out);
 
 /*
  * Increments a Context object's reference count.
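This is a source-breaking change to a public entry point: callers that passed a single `uint32_t*` out-parameter no longer compile. Against the new prototype, a call site looks roughly like this (a sketch assuming a valid `gptoss_context_t`; the buffer size is arbitrary):

```cpp
#include <gpt-oss.h>

#include <array>
#include <cstddef>
#include <cstdint>

// Illustrative call site for the new prototype; returns how many tokens
// were actually sampled, which may be fewer than max_tokens.
static std::size_t sample_some(gptoss_context_t context) {
    std::array<std::uint32_t, 128> tokens;
    std::size_t num_tokens = 0;
    const gptoss_status status = gptoss_context_sample(
        context, /*temperature=*/1.0f, /*seed=*/0,
        /*max_tokens=*/tokens.size(), tokens.data(), &num_tokens);
    return status == gptoss_status_success ? num_tokens : 0;
}
```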
57 changes: 41 additions & 16 deletions gpt_oss/metal/python/context.c
@@ -120,25 +120,54 @@ static PyObject* PyGPTOSSContext_process(PyGPTOSSContext* self) {
 }
 
 static PyObject* PyGPTOSSContext_sample(PyGPTOSSContext* self, PyObject* args, PyObject* kwargs) {
-    static char *kwlist[] = {"temperature", "seed", NULL};
+    static char *kwlist[] = {"max_output_tokens", "temperature", "seed", NULL};
+    PyObject* token_list_obj = NULL;
+    uint32_t* token_ptr = NULL;
 
+    unsigned int max_output_tokens = 0;
     unsigned long long seed = 0;
     float temperature = 1.0f;
-    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|$fK", kwlist,
-                                     &temperature, &seed))
+    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "I|$fK", kwlist,
+                                     &max_output_tokens, &temperature, &seed))
     {
         return NULL;
     }
 
-    uint32_t token_out = UINT32_MAX;
-    enum gptoss_status status = gptoss_context_sample(
-        self->handle, temperature, (uint64_t) seed, &token_out);
+    token_ptr = (uint32_t*) PyMem_Malloc(max_output_tokens * sizeof(uint32_t));
+    if (token_ptr == NULL) {
+        goto error;
+    }
+
+    size_t num_tokens = 0;
+    const enum gptoss_status status = gptoss_context_sample(
+        self->handle, temperature, (uint64_t) seed,
+        (size_t) max_output_tokens, token_ptr, &num_tokens);
     if (status != gptoss_status_success) {
         // TODO: set exception
-        return NULL;
+        goto error;
     }
 
-    return PyLong_FromUnsignedLong((unsigned long) token_out);
+    token_list_obj = PyList_New((Py_ssize_t) num_tokens);
+    if (token_list_obj == NULL) {
+        goto error;
+    }
+
+    for (size_t t = 0; t < num_tokens; t++) {
+        PyObject* token_obj = PyLong_FromUnsignedLong((unsigned long) token_ptr[t]);
+        if (token_obj == NULL) {
+            goto error;
+        }
+
+        PyList_SET_ITEM(token_list_obj, (Py_ssize_t) t, token_obj);
+    }
+
+    PyMem_Free(token_ptr);
+    return token_list_obj;
+
+error:
+    PyMem_Free(token_ptr);
+    Py_XDECREF(token_list_obj);
+    return NULL;
 }
 
 static PyObject* PyGPTOSSContext_reset(PyGPTOSSContext* self) {
@@ -155,7 +184,7 @@ static PyMethodDef PyGPTOSSContext_methods[] = {
     {"__copy__", (PyCFunction) PyGPTOSSContext_copy, METH_NOARGS, "Create a copy of the Context"},
     {"append", (PyCFunction) PyGPTOSSContext_append, METH_O, "Append bytes to the Context"},
     {"process", (PyCFunction) PyGPTOSSContext_process, METH_NOARGS, "Process tokens in the Context"},
-    {"sample", (PyCFunction) PyGPTOSSContext_sample, METH_VARARGS | METH_KEYWORDS, "Sample token prediction from the Context"},
+    {"sample", (PyCFunction) PyGPTOSSContext_sample, METH_VARARGS | METH_KEYWORDS, "Sample token predictions from the Context"},
     {"reset", (PyCFunction) PyGPTOSSContext_reset, METH_NOARGS, "Discard the content of the Context"},
     {NULL},
 };
@@ -184,7 +213,6 @@ static PyObject* PyGPTOSSContext_get_max_tokens(PyGPTOSSContext* self, void* closure)
 
 static PyObject* PyGPTOSSContext_get_tokens(PyGPTOSSContext* self, void* closure) {
     PyObject* token_list_obj = NULL;
-    PyObject* token_obj = NULL;
     uint32_t* token_ptr = NULL;
 
     size_t num_tokens = 0;
@@ -210,22 +238,19 @@ static PyObject* PyGPTOSSContext_get_tokens(PyGPTOSSContext* self, void* closure) {
     }
 
     for (size_t t = 0; t < num_tokens; t++) {
-        token_obj = PyLong_FromUnsignedLong((unsigned long) token_ptr[t]);
+        PyObject* token_obj = PyLong_FromUnsignedLong((unsigned long) token_ptr[t]);
         if (token_obj == NULL) {
             goto error;
         }
-        if (PyList_SetItem(token_list_obj, (Py_ssize_t) t, token_obj) < 0) {
-            goto error;
-        }
-        token_obj = NULL; // PyList_SetItem stole the reference
+
+        PyList_SET_ITEM(token_list_obj, (Py_ssize_t) t, token_obj);
     }
 
     PyMem_Free(token_ptr);
     return token_list_obj;
 
 error:
     PyMem_Free(token_ptr);
-    Py_XDECREF(token_obj);
     Py_XDECREF(token_list_obj);
     return NULL;
 }
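Two things happen in this file. First, the Python-visible `sample` signature changes: `max_output_tokens` becomes a required positional argument (format string `"I|$fK"`) and the method returns a list of token ids instead of a single int. Second, both list-building paths switch from `PyList_SetItem` to `PyList_SET_ITEM`: on a freshly created list whose slots were never filled, the macro form cannot fail and steals the item reference, so the `token_obj = NULL` bookkeeping and the extra `Py_XDECREF(token_obj)` on the error path become unnecessary; decref'ing the list releases any items already stored. A condensed sketch of that ownership pattern (the helper name is illustrative):

```cpp
#include <Python.h>
#include <stdint.h>

// Build a Python list of token ids; returns a new reference, or NULL
// with an exception set by the failing CPython call.
static PyObject* make_token_list(const uint32_t* tokens, size_t count) {
    PyObject* list = PyList_New((Py_ssize_t) count);
    if (list == NULL) {
        return NULL;
    }
    for (size_t i = 0; i < count; i++) {
        PyObject* item = PyLong_FromUnsignedLong((unsigned long) tokens[i]);
        if (item == NULL) {
            Py_DECREF(list);  // also releases items already stored in slots
            return NULL;
        }
        // Safe only on a fresh list: SET_ITEM skips error checking and
        // steals the reference to `item`.
        PyList_SET_ITEM(list, (Py_ssize_t) i, item);
    }
    return list;
}
```

One caveat worth flagging in review: in the new `PyGPTOSSContext_sample`, both the `PyMem_Malloc` failure path and the sampler failure path `goto error` without setting a Python exception (the `// TODO: set exception` survives), so callers would see a bare "NULL result without error" `SystemError` rather than a descriptive one.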
4 changes: 4 additions & 0 deletions gpt_oss/metal/source/accumulate.metal
@@ -12,11 +12,15 @@ kernel void gptoss_f32_accumulate_e4(
     const device float4* input [[ buffer(1) ]],
     const device gptoss_expert_prediction* expert [[ buffer(2) ]],
     device float4* output [[ buffer(3) ]],
+    const device gptoss_control* control [[ buffer(4) ]],
     uint2 gid [[threadgroup_position_in_grid]],
     uint tid [[thread_index_in_threadgroup]],
     uint2 threadgroup_size [[ threads_per_threadgroup ]])
 {
     const uint num_active_experts = 4;
+    if (control->abort != 0) {
+        return;
+    }
 
     const uint num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;
     const uint threadgroup_start = gid.x * num_vecs_per_threadgroup;
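With the control pointer bound at `buffer(4)`, cancellation becomes cooperative: every thread checks the flag once on entry and returns before touching the data buffers, so an in-flight command buffer drains quickly instead of being torn down mid-kernel, and the fast-path cost is a single load. On the host side, under the one-field layout assumed earlier and a CPU-visible (shared-storage) control buffer, requesting an abort is just a store; a hypothetical sketch, not part of this PR:

```cpp
#include <stdint.h>

// Hypothetical host-side cancellation matching the kernel's
// `control->abort != 0` early-out; `abort_flag` points into the
// zero-initialized control buffer.
static void request_abort(volatile uint32_t* abort_flag) {
    *abort_flag = 1;
}
```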