4
4
#include " ggml-cpp.h"
5
5
6
6
#include < cinttypes>
7
+ #include < cstdlib>
7
8
#include < string>
8
9
#include < vector>
9
10
#include < memory>
@@ -97,6 +98,7 @@ enum rpc_cmd {
97
98
RPC_CMD_GET_ALLOC_SIZE,
98
99
RPC_CMD_HELLO,
99
100
RPC_CMD_COUNT,
101
+ RPC_CMD_AUTH,
100
102
};
101
103
102
104
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
@@ -108,6 +110,15 @@ struct rpc_msg_hello_rsp {
108
110
uint8_t patch;
109
111
};
110
112
113
// Authentication request: payload of RPC_CMD_AUTH, the first command a client
// sends on a new connection.
//   length - byte length of the token as reported by the sender
//   token  - token bytes; the client writes them NUL-terminated, so at most
//            sizeof(token) - 1 bytes of the secret fit
// The whole struct is transmitted as-is, so both members are zero-initialised:
// without this the unused tail of `token` would carry indeterminate stack
// bytes to the peer.
struct rpc_msg_auth_req {
    uint16_t length = 0;
    uint8_t  token[256] = {};
};
117
+
118
// Authentication response: the server's reply to RPC_CMD_AUTH.
//   result - non-zero when the presented token was accepted, zero otherwise
// `result` is a fixed-width byte rather than `bool` because this struct is
// filled directly from bytes received over the socket: sizeof(bool) is
// implementation-defined, and a bool object holding anything other than 0/1
// is undefined behaviour. Assigning true/false and comparing against false
// keep working unchanged.
struct rpc_msg_auth_resp {
    uint8_t result;
};
121
+
111
122
struct rpc_msg_get_alloc_size_req {
112
123
rpc_tensor tensor;
113
124
};
@@ -426,8 +437,32 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
426
437
// RPC client-side implementation
427
438
428
439
static bool check_server_version (const std::shared_ptr<socket_t > & sock) {
440
+ const char * auth_token_s = std::getenv (" GGML_RPC_TOKEN" );
441
+
442
+ if (auth_token_s == nullptr ) {
443
+ fprintf (stderr, " No authentication token secret found in environment\n " );
444
+ return false ;
445
+ }
446
+
447
+ rpc_msg_auth_req auth_request;
448
+ auth_request.length = strlen (auth_token_s);
449
+ snprintf ((char *)auth_request.token ,
450
+ sizeof (auth_request.token ),
451
+ " %.*s" ,
452
+ (int )sizeof (auth_request.token )-1 ,
453
+ auth_token_s);
454
+
455
+ rpc_msg_auth_resp auth_response;
456
+ bool status = send_rpc_cmd (sock, RPC_CMD_AUTH, &auth_request, sizeof (rpc_msg_auth_req), &auth_response, sizeof (rpc_msg_auth_resp));
457
+ RPC_STATUS_ASSERT (status);
458
+
459
+ if (auth_response.result == false ) {
460
+ fprintf (stderr, " Failed to authenticate to RPC server\n " );
461
+ return false ;
462
+ }
463
+
429
464
rpc_msg_hello_rsp response;
430
- bool status = send_rpc_cmd (sock, RPC_CMD_HELLO, nullptr , 0 , &response, sizeof (response));
465
+ status = send_rpc_cmd (sock, RPC_CMD_HELLO, nullptr , 0 , &response, sizeof (response));
431
466
RPC_STATUS_ASSERT (status);
432
467
if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
433
468
fprintf (stderr, " RPC server version mismatch: %d.%d.%d\n " , response.major , response.minor , response.patch );
@@ -1371,14 +1406,86 @@ rpc_server::~rpc_server() {
1371
1406
}
1372
1407
}
1373
1408
1409
// Implementation borrowed from https://github.com/chmike/cst_time_memcmp
//
// Compares the first n bytes of m1 and m2 as unsigned chars, always visiting
// every byte regardless of where (or whether) they differ.
// Returns -1, 0 or 1 according to the lowest-index differing byte, i.e. the
// sign of what memcmp would report; n == 0 yields 0.
static int cst_time_memcmp(const void *m1, const void *m2, size_t n) {
    const unsigned char * lhs = static_cast<const unsigned char *>(m1);
    const unsigned char * rhs = static_cast<const unsigned char *>(m2);

    int verdict = 0;
    // Walk from the last byte down to the first so that the difference found
    // at the lowest index is the one left in `verdict` at the end.
    for (size_t remaining = n; remaining != 0; --remaining) {
        const size_t idx   = remaining - 1;
        const int    delta = lhs[idx] - rhs[idx];
        // All-ones when the bytes match (keep the previous verdict),
        // zero when they differ (discard it in favour of delta).
        const int keep_mask = -!delta;
        verdict = (verdict & keep_mask) | delta;
    }

    // Collapse the byte difference to its sign.
    return (verdict > 0) - (verdict < 0);
}
1423
+
1374
1424
static void rpc_serve_client (ggml_backend_t backend, const char * cache_dir,
1375
1425
sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1426
+
1427
+ const char * auth_token_s = std::getenv (" GGML_RPC_TOKEN" );
1428
+ if (auth_token_s == nullptr ) {
1429
+ fprintf (stderr, " [%s] Authentication token secret not set\n " , __func__);
1430
+ return ;
1431
+ }
1432
+
1433
+ size_t auth_token_s_len = strlen (auth_token_s);
1434
+
1376
1435
rpc_server server (backend, cache_dir);
1377
1436
uint8_t cmd;
1437
+
1438
+ if (!recv_data (sockfd, &cmd, 1 )) {
1439
+ return ;
1440
+ }
1441
+
1442
+ // The first command sent by the client must be AUTH
1443
+ if (cmd != RPC_CMD_AUTH) {
1444
+ fprintf (stderr, " Expected AUTH command, update client\n " );
1445
+ return ;
1446
+ }
1447
+
1448
+ rpc_msg_auth_req request;
1449
+ if (!recv_msg (sockfd, &request, sizeof (request))) {
1450
+ fprintf (stderr, " Failed to process AUTH request, update client\n " );
1451
+ return ;
1452
+ }
1453
+
1454
+ rpc_msg_auth_resp auth_response;
1455
+
1456
+ // This is insecure for the following reasons:
1457
+ // 0) It is probably susceptible to cache timing attacks
1458
+ // 1) It may leak the size of the secret auth token
1459
+ // 2) It can be brute forced
1460
+ // 3) It compares secrets directly, not their hashes
1461
+ // 4) It can be intercepted on the wire (use socat/openssl)
1462
+ // 5) The token doesn't expire
1463
+ if (request.length != auth_token_s_len ||
1464
+ cst_time_memcmp ((void *) auth_token_s, (void *) &request.token , auth_token_s_len) != 0 ) {
1465
+ struct sockaddr_in peer_addr;
1466
+ socklen_t peer_len = sizeof (peer_addr);
1467
+
1468
+ if (getpeername (sockfd, (struct sockaddr *)&peer_addr, &peer_len) == 0 ) {
1469
+ char *ip = inet_ntoa (peer_addr.sin_addr );
1470
+ fprintf (stderr, " [%s] Invalid authentication token from %s\n " ,
1471
+ __func__, ip);
1472
+ } else {
1473
+ fprintf (stderr, " [%s] Invalid authentication token from unknown (getpeername failed)\n " ,
1474
+ __func__);
1475
+ }
1476
+ auth_response.result = false ;
1477
+ send_msg (sockfd, &auth_response, sizeof (auth_response));
1478
+ return ;
1479
+ }
1480
+
1481
+ auth_response.result = true ;
1482
+ send_msg (sockfd, &auth_response, sizeof (auth_response));
1483
+
1378
1484
if (!recv_data (sockfd, &cmd, 1 )) {
1379
1485
return ;
1380
1486
}
1381
- // the first command sent by the client must be HELLO
1487
+
1488
+ // The second command sent by the client must be HELLO
1382
1489
if (cmd != RPC_CMD_HELLO) {
1383
1490
fprintf (stderr, " Expected HELLO command, update client\n " );
1384
1491
return ;
0 commit comments