Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 92 additions & 1 deletion tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <memory>
#include <random>
#include <regex>
#include <set>
#include <string>
#include <string_view>
#include <thread>
Expand Down Expand Up @@ -6629,8 +6630,90 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
GGML_ABORT("fatal error");
}

static void list_all_ops() {
printf("GGML operations:\n");
std::set<std::string> all_ops;

for (int i = 1; i < GGML_OP_COUNT; i++) {
all_ops.insert(ggml_op_name((enum ggml_op)i));
}
for (int i = 0; i < GGML_UNARY_OP_COUNT; i++) {
all_ops.insert(ggml_unary_op_name((enum ggml_unary_op)i));
}
for (int i = 0; i < GGML_GLU_OP_COUNT; i++) {
all_ops.insert(ggml_glu_op_name((enum ggml_glu_op)i));
}
for (const auto & op : all_ops) {
printf(" %s\n", op.c_str());
}
printf("\nTotal: %zu operations\n", all_ops.size());
}

static void show_test_coverage() {
std::set<std::string> all_ops;
for (int i = 1; i < GGML_OP_COUNT; i++) {
all_ops.insert(ggml_op_name((enum ggml_op)i));
}
for (int i = 0; i < GGML_UNARY_OP_COUNT; i++) {
all_ops.insert(ggml_unary_op_name((enum ggml_unary_op)i));
}
for (int i = 0; i < GGML_GLU_OP_COUNT; i++) {
all_ops.insert(ggml_glu_op_name((enum ggml_glu_op)i));
}
auto test_cases = make_test_cases_eval();
std::set<std::string> tested_ops;

ggml_init_params params = {
/* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
/* .mem_base = */ NULL,
/* .no_alloc = */ true,
};

for (auto & test_case : test_cases) {
ggml_context * ctx = ggml_init(params);
if (ctx) {
test_case->mode = MODE_TEST;
ggml_tensor * out = test_case->build_graph(ctx);
if (out && out->op != GGML_OP_NONE) {
if (out->op == GGML_OP_UNARY) {
tested_ops.insert(ggml_unary_op_name(ggml_get_unary_op(out)));
} else if (out->op == GGML_OP_GLU) {
tested_ops.insert(ggml_glu_op_name(ggml_get_glu_op(out)));
} else {
tested_ops.insert(ggml_op_name(out->op));
}
}
ggml_free(ctx);
}
}
std::set<std::string> covered_ops;
std::set<std::string> uncovered_ops;
for (const auto & op : all_ops) {
if (tested_ops.count(op) > 0) {
covered_ops.insert(op);
} else {
uncovered_ops.insert(op);
}
}

printf("Operations covered by tests (%zu):\n", covered_ops.size());
for (const auto & op : covered_ops) {
printf(" ✓ %s\n", op.c_str());
}
printf("\nOperations without tests (%zu):\n", uncovered_ops.size());
for (const auto & op : uncovered_ops) {
printf(" ✗ %s\n", op.c_str());
}

printf("\nCoverage Summary:\n");
printf(" Total operations: %zu\n", all_ops.size());
printf(" Tested operations: %zu\n", covered_ops.size());
printf(" Untested operations: %zu\n", uncovered_ops.size());
printf(" Coverage: %.1f%%\n", (double)covered_ops.size() / all_ops.size() * 100.0);
}

static void usage(char ** argv) {
printf("Usage: %s [mode] [-o <op,..>] [-b <backend>] [-p <params regex>] [--output <console|sql|csv>]\n", argv[0]);
printf("Usage: %s [mode] [-o <op,..>] [-b <backend>] [-p <params regex>] [--output <console|sql|csv>] [--list-ops] [--show-coverage]\n", argv[0]);
printf(" valid modes:\n");
printf(" - test (default, compare with CPU backend for correctness)\n");
printf(" - grad (compare gradients from backpropagation with method of finite differences)\n");
Expand All @@ -6639,6 +6722,8 @@ static void usage(char ** argv) {
printf(" op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc),\n");
printf(" optionally including the full test case string (e.g. \"ADD(type=f16,ne=[1,1,8,1],nr=[1,1,1,1],nf=1)\")\n");
printf(" --output specifies output format (default: console, options: console, sql, csv)\n");
printf(" --list-ops lists all available GGML operations\n");
printf(" --show-coverage shows test coverage\n");
}

int main(int argc, char ** argv) {
Expand Down Expand Up @@ -6688,6 +6773,12 @@ int main(int argc, char ** argv) {
usage(argv);
return 1;
}
} else if (strcmp(argv[i], "--list-ops") == 0) {
list_all_ops();
return 0;
} else if (strcmp(argv[i], "--show-coverage") == 0) {
show_test_coverage();
return 0;
} else {
usage(argv);
return 1;
Expand Down
Loading