Skip to content

Commit bea53a6

Browse files
EAddariocompilade
authored andcommitted
imatrix: add option to display importance score statistics for a given imatrix file (ggml-org#12718)
* Add --show-statistics option * Add --show-statistics logic * Add tensor name parsing * Tidy output format * Fix typo in title * Improve tensor influence ranking * Add better statistics * Change statistics' sort order * Add Cosine Similarity * Add header search path * Change header search path to private * Add weighted statistics per layer * Update report title * Refactor compute_statistics out of main * Refactor compute_cossim out of load_imatrix * Refactor compute_statistics out of load_imatrix * Move imatrix statistics calculation into its own functions * Add checks and validations * Remove unnecessary include directory * Rename labels * Add m_stats getter and refactor compute_statistics out of load_imatrix * Refactor variable names * Minor cosmetic change * Retrigger checks (empty commit) * Rerun checks (empty commit) * Fix unnecessary type promotion Co-authored-by: compilade <[email protected]> * Reverting change to improve code readability * Rerun checks (empty commit) * Rerun checks (empty commit) * Rerun checks - third time's the Charm 🤞 (empty commit) * Minor cosmetic change * Update README * Fix typo * Update README * Rerun checks (empty commit) * Re-implement changes on top of ggml-org#9400 * Update README.md * Update README * Update README.md Co-authored-by: compilade <[email protected]> * Update README.md Co-authored-by: compilade <[email protected]> * Update README.md * Remove duplicate option in print_usage() * Update README.md * Update README.md Co-authored-by: compilade <[email protected]> * Update README.md Co-authored-by: compilade <[email protected]> * Remove input check * Remove commented out code --------- Co-authored-by: compilade <[email protected]>
1 parent d4d1522 commit bea53a6

File tree

1 file changed

+109
-2
lines changed

1 file changed

+109
-2
lines changed

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "ggml-cpp.h"
55

66
#include <cinttypes>
7+
#include <cstdlib>
78
#include <string>
89
#include <vector>
910
#include <memory>
@@ -97,6 +98,7 @@ enum rpc_cmd {
9798
RPC_CMD_GET_ALLOC_SIZE,
9899
RPC_CMD_HELLO,
99100
RPC_CMD_COUNT,
101+
RPC_CMD_AUTH,
100102
};
101103

102104
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
@@ -108,6 +110,15 @@ struct rpc_msg_hello_rsp {
108110
uint8_t patch;
109111
};
110112

113+
struct rpc_msg_auth_req {
114+
uint16_t length;
115+
uint8_t token[256];
116+
};
117+
118+
struct rpc_msg_auth_resp {
119+
bool result;
120+
};
121+
111122
struct rpc_msg_get_alloc_size_req {
112123
rpc_tensor tensor;
113124
};
@@ -426,8 +437,32 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
426437
// RPC client-side implementation
427438

428439
static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
440+
const char * auth_token_s = std::getenv("GGML_RPC_TOKEN");
441+
442+
if (auth_token_s == nullptr) {
443+
fprintf(stderr, "No authentication token secret found in environment\n");
444+
return false;
445+
}
446+
447+
rpc_msg_auth_req auth_request;
448+
auth_request.length = strlen(auth_token_s);
449+
snprintf((char *)auth_request.token,
450+
sizeof(auth_request.token),
451+
"%.*s",
452+
(int)sizeof(auth_request.token)-1,
453+
auth_token_s);
454+
455+
rpc_msg_auth_resp auth_response;
456+
bool status = send_rpc_cmd(sock, RPC_CMD_AUTH, &auth_request, sizeof(rpc_msg_auth_req), &auth_response, sizeof(rpc_msg_auth_resp));
457+
RPC_STATUS_ASSERT(status);
458+
459+
if (auth_response.result == false) {
460+
fprintf(stderr, "Failed to authenticate to RPC server\n");
461+
return false;
462+
}
463+
429464
rpc_msg_hello_rsp response;
430-
bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
465+
status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
431466
RPC_STATUS_ASSERT(status);
432467
if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
433468
fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
@@ -1371,14 +1406,86 @@ rpc_server::~rpc_server() {
13711406
}
13721407
}
13731408

1409+
// Implementation borrowed from https://github.com/chmike/cst_time_memcmp
1410+
static int cst_time_memcmp(const void *m1, const void *m2, size_t n) {
1411+
const unsigned char *pm1 = (const unsigned char*)m1;
1412+
const unsigned char *pm2 = (const unsigned char*)m2;
1413+
int res = 0, diff;
1414+
if (n > 0) {
1415+
do {
1416+
--n;
1417+
diff = pm1[n] - pm2[n];
1418+
res = (res & -!diff) | diff;
1419+
} while (n != 0);
1420+
}
1421+
return (res > 0) - (res < 0);
1422+
}
1423+
13741424
static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
13751425
sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1426+
1427+
const char * auth_token_s = std::getenv("GGML_RPC_TOKEN");
1428+
if (auth_token_s == nullptr) {
1429+
fprintf(stderr, "[%s] Authentication token secret not set\n", __func__);
1430+
return;
1431+
}
1432+
1433+
size_t auth_token_s_len = strlen(auth_token_s);
1434+
13761435
rpc_server server(backend, cache_dir);
13771436
uint8_t cmd;
1437+
1438+
if (!recv_data(sockfd, &cmd, 1)) {
1439+
return;
1440+
}
1441+
1442+
// The first command sent by the client must be AUTH
1443+
if (cmd != RPC_CMD_AUTH) {
1444+
fprintf(stderr, "Expected AUTH command, update client\n");
1445+
return;
1446+
}
1447+
1448+
rpc_msg_auth_req request;
1449+
if (!recv_msg(sockfd, &request, sizeof(request))) {
1450+
fprintf(stderr, "Failed to process AUTH request, update client\n");
1451+
return;
1452+
}
1453+
1454+
rpc_msg_auth_resp auth_response;
1455+
1456+
// This is insecure for the following reasons:
1457+
// 0) It is probably susceptible to cache timing attacks
1458+
// 1) It may leak the size of the secret auth token
1459+
// 2) It can be brute forced
1460+
// 3) It compares secrets directly, not their hashes
1461+
// 4) It can be intercepted on the wire (use socat/openssl)
1462+
// 5) The token doesn't expire
1463+
if (request.length != auth_token_s_len ||
1464+
cst_time_memcmp((void *) auth_token_s, (void *) &request.token, auth_token_s_len) != 0) {
1465+
struct sockaddr_in peer_addr;
1466+
socklen_t peer_len = sizeof(peer_addr);
1467+
1468+
if (getpeername(sockfd, (struct sockaddr *)&peer_addr, &peer_len) == 0) {
1469+
char *ip = inet_ntoa(peer_addr.sin_addr);
1470+
fprintf(stderr, "[%s] Invalid authentication token from %s\n",
1471+
__func__, ip);
1472+
} else {
1473+
fprintf(stderr, "[%s] Invalid authentication token from unknown (getpeername failed)\n",
1474+
__func__);
1475+
}
1476+
auth_response.result = false;
1477+
send_msg(sockfd, &auth_response, sizeof(auth_response));
1478+
return;
1479+
}
1480+
1481+
auth_response.result = true;
1482+
send_msg(sockfd, &auth_response, sizeof(auth_response));
1483+
13781484
if (!recv_data(sockfd, &cmd, 1)) {
13791485
return;
13801486
}
1381-
// the first command sent by the client must be HELLO
1487+
1488+
// The second command sent by the client must be HELLO
13821489
if (cmd != RPC_CMD_HELLO) {
13831490
fprintf(stderr, "Expected HELLO command, update client\n");
13841491
return;

0 commit comments

Comments
 (0)