diff --git a/README.md b/README.md index 023849bf..2d6b743b 100644 --- a/README.md +++ b/README.md @@ -9,3 +9,10 @@ This library is the base for [ESPAsyncWebServer](https://github.com/me-no-dev/ES ## AsyncClient and AsyncServer The base classes on which everything else is built. They expose all possible scenarios, but are really raw and require more skills to use. + +## TLS support +Support for TLS is added using mbed TLS, for now only the client part is supported. You can enable this by adding the flag ASYNC_TCP_SSL_ENABLED to your build flags (-DASYNC_TCP_SSL_ENABLED). If you'd like to set a root certificate you can use the setRootCa function on AsyncClient. Feel free to add support for the server side as well :-) + +In addition to the regular certificate based cipher suites there is also support for Pre-Shared Key +cipher suites. Use `setPsk` to define the PSK identifier and PSK itself. The PSK needs to be +provided in the form of a hex string (and easy way to generate a PSK is to use md5sum). diff --git a/component.mk b/component.mk index bb5bb161..65af2463 100644 --- a/component.mk +++ b/component.mk @@ -1,3 +1,5 @@ COMPONENT_ADD_INCLUDEDIRS := src COMPONENT_SRCDIRS := src CXXFLAGS += -fno-rtti +CXXFLAGS += -DASYNC_TCP_SSL_ENABLED=1 +CPPFLAGS += -DASYNC_TCP_SSL_ENABLED=1 \ No newline at end of file diff --git a/src/AsyncTCP.cpp b/src/AsyncTCP.cpp index f06ccccb..2d007132 100644 --- a/src/AsyncTCP.cpp +++ b/src/AsyncTCP.cpp @@ -408,7 +408,6 @@ static tcp_pcb * _tcp_listen_with_backlog(tcp_pcb * pcb, uint8_t backlog) { /* Async TCP Client */ - AsyncClient::AsyncClient(tcp_pcb* pcb) : _connect_cb(0) , _connect_cb_arg(0) @@ -425,6 +424,14 @@ AsyncClient::AsyncClient(tcp_pcb* pcb) , _timeout_cb(0) , _timeout_cb_arg(0) , _pcb_busy(false) +#if ASYNC_TCP_SSL_ENABLED +, _root_ca_len(0) +, _root_ca(NULL) +, _pcb_secure(false) +, _handshake_done(true) +, _psk_ident(0) +, _psk(0) +#endif // ASYNC_TCP_SSL_ENABLED , _pcb_sent_at(0) , _close_pcb(false) , _ack_pcb(true) @@ -436,8 +443,6 @@ AsyncClient::AsyncClient(tcp_pcb* pcb) , next(NULL) , _in_lwip_thread(false) { - //ets_printf("+: 0x%08x\n", (uint32_t)this); - _pcb = pcb; if(_pcb){ _rx_last_packet = millis(); @@ -453,11 +458,13 @@ AsyncClient::AsyncClient(tcp_pcb* pcb) AsyncClient::~AsyncClient(){ if(_pcb) _close(); - - //ets_printf("-: 0x%08x\n", (uint32_t)this); } +#if ASYNC_TCP_SSL_ENABLED +bool AsyncClient::connect(IPAddress ip, uint16_t port, bool secure){ +#else bool AsyncClient::connect(IPAddress ip, uint16_t port){ +#endif // ASYNC_TCP_SSL_ENABLED if (_pcb){ log_w("already connected, state %d", _pcb->state); return false; @@ -477,6 +484,11 @@ bool AsyncClient::connect(IPAddress ip, uint16_t port){ return false; } +#if ASYNC_TCP_SSL_ENABLED + _pcb_secure = secure; + _handshake_done = !secure; +#endif // ASYNC_TCP_SSL_ENABLED + tcp_arg(pcb, this); tcp_err(pcb, &_tcp_error); if(_in_lwip_thread){ @@ -487,6 +499,18 @@ bool AsyncClient::connect(IPAddress ip, uint16_t port){ return true; } +#if ASYNC_TCP_SSL_ENABLED +void AsyncClient::setRootCa(const char* rootca, const size_t len) { + _root_ca = (char*)rootca; + _root_ca_len = len; +} + +void AsyncClient::setPsk(const char* psk_ident, const char* psk) { + _psk_ident = psk_ident; + _psk = psk; +} +#endif // ASYNC_TCP_SSL_ENABLED + AsyncClient& AsyncClient::operator=(const AsyncClient& other){ if (_pcb) _close(); @@ -499,37 +523,80 @@ AsyncClient& AsyncClient::operator=(const AsyncClient& other){ tcp_sent(_pcb, &_tcp_sent); tcp_err(_pcb, &_tcp_error); tcp_poll(_pcb, &_tcp_poll, 1); + +#if ASYNC_TCP_SSL_ENABLED + if(tcp_ssl_has(_pcb)){ + _pcb_secure = true; + _handshake_done = false; + tcp_ssl_arg(_pcb, this); + tcp_ssl_data(_pcb, &_s_data); + tcp_ssl_handshake(_pcb, &_s_handshake); + tcp_ssl_err(_pcb, &_s_ssl_error); + } else { + _pcb_secure = false; + _handshake_done = true; + } +#endif // ASYNC_TCP_SSL_ENABLED } return *this; } int8_t AsyncClient::_connected(void* pcb, int8_t err){ _pcb = reinterpret_cast(pcb); + if(_pcb){ _rx_last_packet = millis(); _pcb_busy = false; tcp_recv(_pcb, &_tcp_recv); tcp_sent(_pcb, &_tcp_sent); tcp_poll(_pcb, &_tcp_poll, 1); +#if ASYNC_TCP_SSL_ENABLED + if(_pcb_secure){ + bool err = false; + if(_root_ca) { + err = tcp_ssl_new_client(_pcb, _hostname.empty() ? NULL : _hostname.c_str(), _root_ca, _root_ca_len) < 0; + } else { + err = tcp_ssl_new_psk_client(_pcb, _psk_ident, _psk) < 0; + } + if (err) { + log_e("closing...."); + return _close(); + } + + tcp_ssl_arg(_pcb, this); + tcp_ssl_data(_pcb, &_s_data); + tcp_ssl_handshake(_pcb, &_s_handshake); + tcp_ssl_err(_pcb, &_s_ssl_error); + } +#endif // ASYNC_TCP_SSL_ENABLED } + _in_lwip_thread = true; +#if ASYNC_TCP_SSL_ENABLED + if(!_pcb_secure && _connect_cb) +#else if(_connect_cb) +#endif // ASYNC_TCP_SSL_ENABLED _connect_cb(_connect_cb_arg, this); _in_lwip_thread = false; + return ERR_OK; } int8_t AsyncClient::_close(){ - //ets_printf("X: 0x%08x\n", (uint32_t)this); int8_t err = ERR_OK; + if(_pcb) { - //log_i(""); +#if ASYNC_TCP_SSL_ENABLED + if(_pcb_secure){ + tcp_ssl_free(_pcb); + } +#endif // ASYNC_TCP_SSL_ENABLED tcp_arg(_pcb, NULL); tcp_sent(_pcb, NULL); tcp_recv(_pcb, NULL); tcp_err(_pcb, NULL); tcp_poll(_pcb, NULL, 0); - _tcp_clear_events(this); if(_in_lwip_thread){ err = tcp_close(_pcb); } else { @@ -539,6 +606,7 @@ int8_t AsyncClient::_close(){ err = abort(); } _pcb = NULL; + // _tcp_clear_events(this); if(_discard_cb) _discard_cb(_discard_cb_arg, this); } @@ -546,6 +614,7 @@ int8_t AsyncClient::_close(){ } void AsyncClient::_error(int8_t err) { + log_e("Error!! %d", err); if(_pcb){ tcp_arg(_pcb, NULL); tcp_sent(_pcb, NULL); @@ -560,7 +629,19 @@ void AsyncClient::_error(int8_t err) { _discard_cb(_discard_cb_arg, this); } +#if ASYNC_TCP_SSL_ENABLED +void AsyncClient::_ssl_error(int8_t err){ + if(_error_cb) + _error_cb(_error_cb_arg, this, err+64); +} +#endif // ASYNC_TCP_SSL_ENABLED + int8_t AsyncClient::_sent(tcp_pcb* pcb, uint16_t len) { +#if ASYNC_TCP_SSL_ENABLED + if (_pcb_secure && !_handshake_done) + return ERR_OK; +#endif // ASYNC_TCP_SSL_ENABLED + _in_lwip_thread = false; _rx_last_packet = millis(); //log_i("%u", len); @@ -572,12 +653,13 @@ int8_t AsyncClient::_sent(tcp_pcb* pcb, uint16_t len) { int8_t AsyncClient::_recv(tcp_pcb* pcb, pbuf* pb, int8_t err) { if(!_pcb || pcb != _pcb){ - log_e("0x%08x != 0x%08x", (uint32_t)pcb, (uint32_t)_pcb); - if(pb){ - pbuf_free(pb); - } - return ERR_OK; - } + log_e("0x%08x != 0x%08x", (uint32_t)pcb, (uint32_t)_pcb); + if(pb){ + pbuf_free(pb); + } + return ERR_OK; + } + _in_lwip_thread = false; if(pb == NULL){ return _close(); @@ -585,6 +667,21 @@ int8_t AsyncClient::_recv(tcp_pcb* pcb, pbuf* pb, int8_t err) { while(pb != NULL){ _rx_last_packet = millis(); +#if ASYNC_TCP_SSL_ENABLED + if(_pcb_secure){ + // log_i("_recv: %d\n", pb->tot_len); + int read_bytes = tcp_ssl_read(pcb, pb); + if(read_bytes < 0){ + if (read_bytes != MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { + log_e("_recv err: %d\n", read_bytes); + _close(); + } + + //return read_bytes; + } + return ERR_OK; + } +#endif // ASYNC_TCP_SSL_ENABLED //we should not ack before we assimilate the data //log_i("%u", pb->len); //Serial.write((const uint8_t *)pb->payload, pb->len); @@ -631,6 +728,12 @@ int8_t AsyncClient::_poll(tcp_pcb* pcb){ _close(); return ERR_OK; } +#if ASYNC_TCP_SSL_ENABLED + if(_pcb_secure && !_handshake_done && (now - _rx_last_packet) >= 2000){ + _close(); + return ERR_OK; + } +#endif // ASYNC_TCP_SSL_ENABLED // Everything is fine if(_poll_cb) _poll_cb(_poll_cb_arg, this); @@ -639,8 +742,13 @@ int8_t AsyncClient::_poll(tcp_pcb* pcb){ void AsyncClient::_dns_found(struct ip_addr *ipaddr){ _in_lwip_thread = true; + if(ipaddr){ +#if ASYNC_TCP_SSL_ENABLED + connect(IPAddress(ipaddr->u_addr.ip4.addr), _connect_port, _pcb_secure); +#else connect(IPAddress(ipaddr->u_addr.ip4.addr), _connect_port); +#endif // ASYNC_TCP_SSL_ENABLED } else { log_e("dns fail"); if(_error_cb) @@ -655,13 +763,29 @@ bool AsyncClient::operator==(const AsyncClient &other) { return _pcb == other._pcb; } +#if ASYNC_TCP_SSL_ENABLED +bool AsyncClient::connect(const char* host, uint16_t port, bool secure){ +#else bool AsyncClient::connect(const char* host, uint16_t port){ +#endif // ASYNC_TCP_SSL_ENABLED ip_addr_t addr; + err_t err = dns_gethostbyname(host, &addr, (dns_found_callback)&_s_dns_found, this); if(err == ERR_OK) { + _hostname = host; + +#if ASYNC_TCP_SSL_ENABLED + return connect(IPAddress(addr.u_addr.ip4.addr), port, secure); +#else return connect(IPAddress(addr.u_addr.ip4.addr), port); +#endif // ASYNC_TCP_SSL_ENABLED } else if(err == ERR_INPROGRESS) { + _hostname = host; _connect_port = port; +#if ASYNC_TCP_SSL_ENABLED + _pcb_secure = secure; + _handshake_done = !secure; +#endif // ASYNC_TCP_SSL_ENABLED return true; } log_e("error: %d", err); @@ -727,13 +851,25 @@ size_t AsyncClient::write(const char* data, size_t size, uint8_t apiflags) { return will_send; } - size_t AsyncClient::add(const char* data, size_t size, uint8_t apiflags) { if(!_pcb || size == 0 || data == NULL) return 0; size_t room = space(); if(!room) return 0; +#if ASYNC_TCP_SSL_ENABLED + if(_pcb_secure){ + int sent = tcp_ssl_write(_pcb, (uint8_t*)data, size); + if(sent >= 0){ + // @ToDo: ??? + //_tx_unacked_len += sent; + return sent; + } + //log_i("add: tcp_ssl_write: %d", sent); + _close(); + return 0; + } +#endif // ASYNC_TCP_SSL_ENABLED size_t will_send = (room < size) ? room : size; int8_t err = ERR_OK; if(_in_lwip_thread){ @@ -747,6 +883,10 @@ size_t AsyncClient::add(const char* data, size_t size, uint8_t apiflags) { } bool AsyncClient::send(){ +#if ASYNC_TCP_SSL_ENABLED + if(_pcb_secure) + return true; +#endif // ASYNC_TCP_SSL_ENABLED int8_t err = ERR_OK; if(_in_lwip_thread){ err = tcp_output(_pcb); @@ -958,7 +1098,6 @@ void AsyncClient::onPoll(AcConnectHandler cb, void* arg){ _poll_cb_arg = arg; } - void AsyncClient::_s_dns_found(const char * name, struct ip_addr * ipaddr, void * arg){ if(arg){ reinterpret_cast(arg)->_dns_found(ipaddr); @@ -981,7 +1120,7 @@ int8_t AsyncClient::_s_recv(void * arg, struct tcp_pcb * pcb, struct pbuf *pb, i reinterpret_cast(arg)->_recv(pcb, pb, err); } else { if(pb){ - pbuf_free(pb); + pbuf_free(pb); } log_e("Bad Args: 0x%08x 0x%08x", arg, pcb); } @@ -1014,6 +1153,25 @@ int8_t AsyncClient::_s_connected(void * arg, void * pcb, int8_t err){ return ERR_OK; } +#if ASYNC_TCP_SSL_ENABLED +void AsyncClient::_s_data(void *arg, struct tcp_pcb *tcp, uint8_t * data, size_t len){ + AsyncClient *c = reinterpret_cast(arg); + if(c->_recv_cb) + c->_recv_cb(c->_recv_cb_arg, c, data, len); +} + +void AsyncClient::_s_handshake(void *arg, struct tcp_pcb *tcp, struct tcp_ssl_pcb* ssl){ + AsyncClient *c = reinterpret_cast(arg); + c->_handshake_done = true; + if(c->_connect_cb) + c->_connect_cb(c->_connect_cb_arg, c); +} + +void AsyncClient::_s_ssl_error(void *arg, struct tcp_pcb *tcp, int8_t err){ + reinterpret_cast(arg)->_ssl_error(err); +} +#endif // ASYNC_TCP_SSL_ENABLED + const char * AsyncClient::errorToString(int8_t error){ switch(error){ case 0: return "OK"; diff --git a/src/AsyncTCP.h b/src/AsyncTCP.h index 6cd0fca2..8117a63f 100644 --- a/src/AsyncTCP.h +++ b/src/AsyncTCP.h @@ -22,12 +22,16 @@ #ifndef ASYNCTCP_H_ #define ASYNCTCP_H_ +#include "Arduino.h" #include "IPAddress.h" #include +#include +#include extern "C" { #include "freertos/semphr.h" #include "lwip/pbuf.h" } +#include "tcp_mbedtls.h" class AsyncClient; @@ -48,6 +52,7 @@ struct ip_addr; class AsyncClient { protected: tcp_pcb* _pcb; + std::string _hostname; AcConnectHandler _connect_cb; void* _connect_cb_arg; @@ -67,6 +72,14 @@ class AsyncClient { void* _poll_cb_arg; bool _pcb_busy; +#if ASYNC_TCP_SSL_ENABLED + size_t _root_ca_len; + char* _root_ca; + bool _pcb_secure; + bool _handshake_done; + const char* _psk_ident; + const char* _psk; +#endif // ASYNC_TCP_SSL_ENABLED uint32_t _pcb_sent_at; bool _close_pcb; bool _ack_pcb; @@ -79,10 +92,17 @@ class AsyncClient { int8_t _close(); int8_t _connected(void* pcb, int8_t err); void _error(int8_t err); +#if ASYNC_TCP_SSL_ENABLED + void _ssl_error(int8_t err); +#endif // ASYNC_TCP_SSL_ENABLED int8_t _poll(tcp_pcb* pcb); int8_t _sent(tcp_pcb* pcb, uint16_t len); void _dns_found(struct ip_addr *ipaddr); - +#if ASYNC_TCP_SSL_ENABLED + static void _s_data(void *arg, struct tcp_pcb *tcp, uint8_t * data, size_t len); + static void _s_handshake(void *arg, struct tcp_pcb *tcp, struct tcp_ssl_pcb* ssl); + static void _s_ssl_error(void *arg, struct tcp_pcb *tcp, int8_t err); +#endif // ASYNC_TCP_SSL_ENABLED public: AsyncClient* prev; @@ -99,8 +119,16 @@ class AsyncClient { bool operator!=(const AsyncClient &other) { return !(*this == other); } + +#if ASYNC_TCP_SSL_ENABLED + bool connect(IPAddress ip, uint16_t port, bool secure = false); + bool connect(const char* host, uint16_t port, bool secure = false); + void setRootCa(const char* rootca, const size_t len); + void setPsk(const char* psk_ident, const char* psk); +#else bool connect(IPAddress ip, uint16_t port); bool connect(const char* host, uint16_t port); +#endif // ASYNC_TCP_SSL_ENABLED void close(bool now = false); void stop(); int8_t abort(); @@ -166,6 +194,10 @@ class AsyncClient { bool _in_lwip_thread; }; +#if ASYNC_TCP_SSL_ENABLED +typedef std::function AcSSlFileHandler; +#endif + class AsyncServer { protected: uint16_t _port; @@ -182,6 +214,11 @@ class AsyncServer { AsyncServer(uint16_t port); ~AsyncServer(); void onClient(AcConnectHandler cb, void* arg); +#if ASYNC_TCP_SSL_ENABLED + // Dummy, so it compiles with ESP Async WebServer library enabled. + void onSslFileRequest(AcSSlFileHandler cb, void* arg) {}; + void beginSecure(const char *cert, const char *private_key_file, const char *password) {}; +#endif void begin(); void end(); void setNoDelay(bool nodelay); diff --git a/src/tcp_mbedtls.c b/src/tcp_mbedtls.c new file mode 100644 index 00000000..fc06d72b --- /dev/null +++ b/src/tcp_mbedtls.c @@ -0,0 +1,550 @@ +#if ASYNC_TCP_SSL_ENABLED + +#include "tcp_mbedtls.h" +#include "lwip/tcp.h" +#include "mbedtls/debug.h" +#include "mbedtls/esp_debug.h" +#include + +// #define TCP_SSL_DEBUG(...) ets_printf(__VA_ARGS__) +#define TCP_SSL_DEBUG(...) + +static const char pers[] = "esp32-tls"; + +static int handle_error(int err) { + if(err == -30848){ + return err; + } +#ifdef MBEDTLS_ERROR_C + char error_buf[100]; + mbedtls_strerror(err, error_buf, 100); + TCP_SSL_DEBUG("%s\n", error_buf); +#endif + TCP_SSL_DEBUG("MbedTLS message code: %d\n", err); + return err; +} + +/** + * Certificate verification callback for mbed TLS + * Here we only use it to display information on each cert in the chain + */ +// static int my_verify(void *data, mbedtls_x509_crt *crt, int depth, uint32_t *flags) { +// const uint32_t buf_size = 1024; +// char buf[buf_size]; +// (void) data; + +// mbedtls_printf("\nVerifying certificate at depth %d:\n", depth); +// mbedtls_x509_crt_info(buf, buf_size - 1, " ", crt); +// mbedtls_printf("%s", buf); + +// if (*flags == 0) +// mbedtls_printf("No verification issue for this certificate\n"); +// else +// { +// mbedtls_x509_crt_verify_info(buf, buf_size, " ! ", *flags); +// mbedtls_printf("%s\n", buf); +// } + +// return 0; +// } + +static uint8_t _tcp_ssl_has_client = 0; + +struct tcp_ssl_pcb { + struct tcp_pcb *tcp; + int fd; + mbedtls_ssl_context ssl_ctx; + mbedtls_ssl_config ssl_conf; + mbedtls_x509_crt ca_cert; + mbedtls_ctr_drbg_context drbg_ctx; + mbedtls_entropy_context entropy_ctx; + uint8_t type; + // int handshake; + void* arg; + tcp_ssl_data_cb_t on_data; + tcp_ssl_handshake_cb_t on_handshake; + tcp_ssl_error_cb_t on_error; + size_t last_wr; + struct pbuf *tcp_pbuf; + int pbuf_offset; + struct tcp_ssl_pcb* next; +}; + +typedef struct tcp_ssl_pcb tcp_ssl_t; + +static tcp_ssl_t * tcp_ssl_array = NULL; +static int tcp_ssl_next_fd = 0; + +int tcp_ssl_recv(void *ctx, unsigned char *buf, size_t len) { + tcp_ssl_t *tcp_ssl = (tcp_ssl_t*)ctx; + uint8_t *read_buf = NULL; + uint8_t *pread_buf = NULL; + u16_t recv_len = 0; + + if(tcp_ssl->tcp_pbuf == NULL || tcp_ssl->tcp_pbuf->tot_len == 0) { + TCP_SSL_DEBUG("tcp_ssl_recv: not yet ready to read: tcp_pbuf: 0x%X.\n", tcp_ssl->tcp_pbuf); + return MBEDTLS_ERR_SSL_WANT_READ; + } + + read_buf =(uint8_t*)calloc(tcp_ssl->tcp_pbuf->len + 1, sizeof(uint8_t)); + pread_buf = read_buf; + if (pread_buf != NULL){ + recv_len = pbuf_copy_partial(tcp_ssl->tcp_pbuf, read_buf, len, tcp_ssl->pbuf_offset); + TCP_SSL_DEBUG("tcp_ssl_recv: len: %d, recv_len: %d, pbuf_offset: %d, tcp_pbuf len: %d.\n", len, recv_len, tcp_ssl->pbuf_offset, tcp_ssl->tcp_pbuf->len); + tcp_ssl->pbuf_offset += recv_len; + } + + // Note: why copy again? + if (recv_len != 0) { + memcpy(buf, read_buf, recv_len); + } + + if(len < recv_len) { + TCP_SSL_DEBUG("tcp_ssl_recv: got %d bytes more than expected\n", recv_len - len); + } + + free(pread_buf); + pread_buf = NULL; + + if(recv_len == 0) { + return MBEDTLS_ERR_SSL_WANT_READ; + } + + return recv_len; +} + +int tcp_ssl_send(void *ctx, const unsigned char *buf, size_t len) { + TCP_SSL_DEBUG("tcp_ssl_send: ctx: 0x%X, buf: 0x%X, len: %d\n", ctx, buf, len); + + if(ctx == NULL) { + TCP_SSL_DEBUG("tcp_ssl_send: no context set\n"); + return -1; + } + + if(buf == NULL) { + TCP_SSL_DEBUG("tcp_ssl_send: buf not set\n"); + return -1; + } + + tcp_ssl_t *tcp_ssl = (tcp_ssl_t*)ctx; + size_t tcp_len = 0; + int err = ERR_OK; + + if (tcp_sndbuf(tcp_ssl->tcp) < len) { + tcp_len = tcp_sndbuf(tcp_ssl->tcp); + if(tcp_len == 0) { + TCP_SSL_DEBUG("ax_port_write: tcp_sndbuf is zero: %d\n", len); + return ERR_MEM; + } + } else { + tcp_len = len; + } + + if (tcp_len > 2 * tcp_ssl->tcp->mss) { + tcp_len = 2 * tcp_ssl->tcp->mss; + } + + err = tcp_write(tcp_ssl->tcp, buf, tcp_len, TCP_WRITE_FLAG_COPY); + if(err < ERR_OK) { + if (err == ERR_MEM) { + TCP_SSL_DEBUG("ax_port_write: No memory %d (%d)\n", tcp_len, len); + return err; + } + TCP_SSL_DEBUG("ax_port_write: tcp_write error: %d\n", err); + return err; + } else if (err == ERR_OK) { + //TCP_SSL_DEBUG("ax_port_write: tcp_output: %d / %d\n", tcp_len, len); + err = tcp_output(tcp_ssl->tcp); + if(err != ERR_OK) { + TCP_SSL_DEBUG("ax_port_write: tcp_output err: %d\n", err); + return err; + } + } + + tcp_ssl->last_wr += tcp_len; + + return tcp_len; +} + +uint8_t tcp_ssl_has_client() { + return _tcp_ssl_has_client; +} + +tcp_ssl_t * tcp_ssl_new(struct tcp_pcb *tcp) { + + if(tcp_ssl_next_fd < 0){ + tcp_ssl_next_fd = 0;//overflow + } + + tcp_ssl_t * new_item = (tcp_ssl_t*)malloc(sizeof(tcp_ssl_t)); + if(!new_item){ + TCP_SSL_DEBUG("tcp_ssl_new: failed to allocate tcp_ssl\n"); + return NULL; + } + + new_item->tcp = tcp; + new_item->arg = NULL; + new_item->on_data = NULL; + new_item->on_handshake = NULL; + new_item->on_error = NULL; + new_item->tcp_pbuf = NULL; + new_item->pbuf_offset = 0; + new_item->next = NULL; + + if(tcp_ssl_array == NULL){ + tcp_ssl_array = new_item; + } else { + tcp_ssl_t * item = tcp_ssl_array; + while(item->next != NULL) + item = item->next; + item->next = new_item; + } + + return new_item; +} + +tcp_ssl_t* tcp_ssl_get(struct tcp_pcb *tcp) { + if(tcp == NULL) { + return NULL; + } + tcp_ssl_t * item = tcp_ssl_array; + while(item && item->tcp != tcp){ + item = item->next; + } + return item; +} + +int tcp_ssl_new_client(struct tcp_pcb *tcp, const char* hostname, const char* root_ca, const size_t root_ca_len) { + tcp_ssl_t* tcp_ssl; + + if(tcp == NULL) { + return -1; + } + + if(tcp_ssl_get(tcp) != NULL){ + return -1; + } + + tcp_ssl = tcp_ssl_new(tcp); + if(tcp_ssl == NULL){ + return -1; + } + + mbedtls_entropy_init(&tcp_ssl->entropy_ctx); + mbedtls_ctr_drbg_init(&tcp_ssl->drbg_ctx); + mbedtls_ssl_init(&tcp_ssl->ssl_ctx); + mbedtls_ssl_config_init(&tcp_ssl->ssl_conf); + + mbedtls_ctr_drbg_seed(&tcp_ssl->drbg_ctx, mbedtls_entropy_func, + &tcp_ssl->entropy_ctx, (const unsigned char*)pers, strlen(pers)); + + if(mbedtls_ssl_config_defaults(&tcp_ssl->ssl_conf, + MBEDTLS_SSL_IS_CLIENT, + MBEDTLS_SSL_TRANSPORT_STREAM, + MBEDTLS_SSL_PRESET_DEFAULT)) { + TCP_SSL_DEBUG("error setting SSL config.\n"); + + tcp_ssl_free(tcp); + return -1; + } + + int ret = 0; + + if(root_ca != NULL && root_ca_len > 0) { + TCP_SSL_DEBUG("setting the root ca.\n"); + + mbedtls_x509_crt_init(&tcp_ssl->ca_cert); + + mbedtls_ssl_conf_authmode(&tcp_ssl->ssl_conf, MBEDTLS_SSL_VERIFY_REQUIRED); + + ret = mbedtls_x509_crt_parse(&tcp_ssl->ca_cert, (const unsigned char *)root_ca, root_ca_len); + if( ret < 0 ){ + TCP_SSL_DEBUG(" failed\n ! mbedtls_x509_crt_parse returned -0x%x\n\n", -ret); + return handle_error(ret); + } + + mbedtls_ssl_conf_ca_chain(&tcp_ssl->ssl_conf, &tcp_ssl->ca_cert, NULL); + } else { + mbedtls_ssl_conf_authmode(&tcp_ssl->ssl_conf, MBEDTLS_SSL_VERIFY_OPTIONAL); + } + + if(hostname != NULL) { + TCP_SSL_DEBUG("setting the hostname: %s\n", hostname); + if((ret = mbedtls_ssl_set_hostname(&tcp_ssl->ssl_ctx, hostname)) != 0){ + tcp_ssl_free(tcp); + + return handle_error(ret); + } + } + + mbedtls_ssl_conf_rng(&tcp_ssl->ssl_conf, mbedtls_ctr_drbg_random, &tcp_ssl->drbg_ctx); + // mbedtls_ssl_conf_verify(&tcp_ssl->ssl_conf, my_verify, NULL); + + if ((ret = mbedtls_ssl_setup(&tcp_ssl->ssl_ctx, &tcp_ssl->ssl_conf)) != 0) { + tcp_ssl_free(tcp); + + return handle_error(ret); + } + + mbedtls_ssl_set_bio(&tcp_ssl->ssl_ctx, (void*)tcp_ssl, tcp_ssl_send, tcp_ssl_recv, NULL); + + // Start handshake. + ret = mbedtls_ssl_handshake(&tcp_ssl->ssl_ctx); + if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) { + TCP_SSL_DEBUG("handshake error!\n"); + return handle_error(ret); + } + + return ERR_OK; +} + +// Open an SSL connection using a PSK (pre-shared-key) cipher suite. +int tcp_ssl_new_psk_client(struct tcp_pcb *tcp, const char* psk_ident, const char* pskey) { + tcp_ssl_t* tcp_ssl; + + if(tcp == NULL) return -1; + if(tcp_ssl_get(tcp) != NULL) return -1; + + tcp_ssl = tcp_ssl_new(tcp); + if(tcp_ssl == NULL) return -1; + + mbedtls_entropy_init(&tcp_ssl->entropy_ctx); + mbedtls_ctr_drbg_init(&tcp_ssl->drbg_ctx); + mbedtls_ssl_init(&tcp_ssl->ssl_ctx); + mbedtls_ssl_config_init(&tcp_ssl->ssl_conf); + + mbedtls_ctr_drbg_seed(&tcp_ssl->drbg_ctx, mbedtls_entropy_func, + &tcp_ssl->entropy_ctx, (const unsigned char*)pers, strlen(pers)); + + if(mbedtls_ssl_config_defaults(&tcp_ssl->ssl_conf, + MBEDTLS_SSL_IS_CLIENT, + MBEDTLS_SSL_TRANSPORT_STREAM, + MBEDTLS_SSL_PRESET_DEFAULT)) { + TCP_SSL_DEBUG("error setting SSL config.\n"); + + tcp_ssl_free(tcp); + return -1; + } + + int ret = 0; + + TCP_SSL_DEBUG("setting the pre-shared key.\n"); + // convert PSK from hex string to binary + if ((strlen(pskey) & 1) != 0 || strlen(pskey) > 2*MBEDTLS_PSK_MAX_LEN) { + TCP_SSL_DEBUG(" failed\n ! pre-shared key not valid hex or too long\n\n"); + return -1; + } + unsigned char psk[MBEDTLS_PSK_MAX_LEN]; + size_t psk_len = strlen(pskey)/2; + for (int j=0; j= '0' && c <= '9') c -= '0'; + else if (c >= 'A' && c <= 'F') c -= 'A' - 10; + else if (c >= 'a' && c <= 'f') c -= 'a' - 10; + else return -1; + psk[j/2] = c<<4; + c = pskey[j+1]; + if (c >= '0' && c <= '9') c -= '0'; + else if (c >= 'A' && c <= 'F') c -= 'A' - 10; + else if (c >= 'a' && c <= 'f') c -= 'a' - 10; + else return -1; + psk[j/2] |= c; + } + // set mbedtls config + ret = mbedtls_ssl_conf_psk(&tcp_ssl->ssl_conf, psk, psk_len, + (const unsigned char *)psk_ident, strlen(psk_ident)); + if (ret != 0) { + TCP_SSL_DEBUG(" failed\n ! mbedtls_ssl_conf_psk returned -0x%x\n\n", -ret); + return handle_error(ret); + } + + mbedtls_ssl_conf_rng(&tcp_ssl->ssl_conf, mbedtls_ctr_drbg_random, &tcp_ssl->drbg_ctx); + + if ((ret = mbedtls_ssl_setup(&tcp_ssl->ssl_ctx, &tcp_ssl->ssl_conf)) != 0) { + tcp_ssl_free(tcp); + + return handle_error(ret); + } + + mbedtls_ssl_set_bio(&tcp_ssl->ssl_ctx, (void*)tcp_ssl, tcp_ssl_send, tcp_ssl_recv, NULL); + + // Start handshake. + ret = mbedtls_ssl_handshake(&tcp_ssl->ssl_ctx); + if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) { + TCP_SSL_DEBUG("handshake error!\n"); + return handle_error(ret); + } + + return ERR_OK; +} + +int tcp_ssl_write(struct tcp_pcb *tcp, uint8_t *data, size_t len) { + if(tcp == NULL) { + return -1; + } + + tcp_ssl_t * tcp_ssl = tcp_ssl_get(tcp); + + if(tcp_ssl == NULL){ + return 0; + } + + tcp_ssl->last_wr = 0; + + int rc = mbedtls_ssl_write(&tcp_ssl->ssl_ctx, data, len); + + if (rc < 0){ + if (rc != MBEDTLS_ERR_SSL_WANT_READ && rc != MBEDTLS_ERR_SSL_WANT_WRITE) { + TCP_SSL_DEBUG("about to call mbedtls_ssl_write\n"); + return handle_error(rc); + } + if(rc != MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { + TCP_SSL_DEBUG("tcp_ssl_write error: %d\r\n", rc); + } + return rc; + } + + return tcp_ssl->last_wr; +} + +int tcp_ssl_read(struct tcp_pcb *tcp, struct pbuf *p) { + if(tcp == NULL) { + return -1; + } + tcp_ssl_t* tcp_ssl = NULL; + + int read_bytes = 0; + int total_bytes = 0; + static const size_t read_buf_size = 1024; + uint8_t read_buf[read_buf_size]; + + tcp_ssl = tcp_ssl_get(tcp); + if(tcp_ssl == NULL) { + return ERR_TCP_SSL_INVALID_CLIENTFD_DATA; + } + + if(p == NULL) { + return ERR_TCP_SSL_INVALID_DATA; + } + + // TCP_SSL_DEBUG("READY TO READ SOME DATA\n"); + + tcp_ssl->tcp_pbuf = p; + tcp_ssl->pbuf_offset = 0; + + do { + if(tcp_ssl->ssl_ctx.state != MBEDTLS_SSL_HANDSHAKE_OVER) { + TCP_SSL_DEBUG("start handshake: %d\n", tcp_ssl->ssl_ctx.state); + int ret = mbedtls_ssl_handshake(&tcp_ssl->ssl_ctx); + //handle_error(ret); + if(ret == 0) { + TCP_SSL_DEBUG("Protocol is %s Ciphersuite is %s\n", mbedtls_ssl_get_version(&tcp_ssl->ssl_ctx), mbedtls_ssl_get_ciphersuite(&tcp_ssl->ssl_ctx)); + + if(tcp_ssl->on_handshake) + tcp_ssl->on_handshake(tcp_ssl->arg, tcp_ssl->tcp, tcp_ssl); + } else if(ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) { + TCP_SSL_DEBUG("handshake error: %d\n", ret); + + if(tcp_ssl->on_error) + tcp_ssl->on_error(tcp_ssl->arg, tcp_ssl->tcp, ret); + + break; + } + } else { + read_bytes = mbedtls_ssl_read(&tcp_ssl->ssl_ctx, &read_buf, read_buf_size); + TCP_SSL_DEBUG("tcp_ssl_read: read_bytes: %d, total_bytes: %d, tot_len: %d, pbuf_offset: %d\r\n", read_bytes, total_bytes, p->tot_len, tcp_ssl->pbuf_offset); + if(read_bytes < 0) { // SSL_OK + if(read_bytes == MBEDTLS_ERR_SSL_WANT_READ) { + break; + } else if(read_bytes != MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { + TCP_SSL_DEBUG("tcp_ssl_read: read error: %d\n", read_bytes); + } + total_bytes = read_bytes; + break; + } else if(read_bytes > 0){ + if(tcp_ssl->on_data){ + tcp_ssl->on_data(tcp_ssl->arg, tcp, read_buf, read_bytes); + } + total_bytes+= read_bytes; + } + } + } while (p->tot_len - tcp_ssl->pbuf_offset > 0 || read_bytes > 0); + + tcp_recved(tcp, p->tot_len); + tcp_ssl->tcp_pbuf = NULL; + pbuf_free(p); + + return total_bytes; +} + +int tcp_ssl_free(struct tcp_pcb *tcp) { + if(tcp == NULL) { + return -1; + } + tcp_ssl_t * item = tcp_ssl_array; + if(item->tcp == tcp){ + tcp_ssl_array = tcp_ssl_array->next; + if(item->tcp_pbuf != NULL) { + pbuf_free(item->tcp_pbuf); + } + mbedtls_ssl_free(&item->ssl_ctx); + mbedtls_ssl_config_free(&item->ssl_conf); + mbedtls_ctr_drbg_free(&item->drbg_ctx); + mbedtls_entropy_free(&item->entropy_ctx); + free(item); + return 0; + } + + while(item->next && item->next->tcp != tcp) + item = item->next; + + if(item->next == NULL){ + return ERR_TCP_SSL_INVALID_CLIENTFD_DATA;//item not found + } + tcp_ssl_t * i = item->next; + item->next = i->next; + if(i->tcp_pbuf != NULL){ + pbuf_free(i->tcp_pbuf); + } + mbedtls_ssl_free(&i->ssl_ctx); + mbedtls_ssl_config_free(&i->ssl_conf); + mbedtls_ctr_drbg_free(&i->drbg_ctx); + mbedtls_entropy_free(&i->entropy_ctx); + free(i); + + return 0; +} + +bool tcp_ssl_has(struct tcp_pcb *tcp) { + return tcp_ssl_get(tcp) != NULL; +} + +void tcp_ssl_arg(struct tcp_pcb *tcp, void * arg) { + tcp_ssl_t * item = tcp_ssl_get(tcp); + if(item) { + item->arg = arg; + } +} + +void tcp_ssl_data(struct tcp_pcb *tcp, tcp_ssl_data_cb_t arg){ + tcp_ssl_t * item = tcp_ssl_get(tcp); + if(item) { + item->on_data = arg; + } +} + +void tcp_ssl_handshake(struct tcp_pcb *tcp, tcp_ssl_handshake_cb_t arg){ + tcp_ssl_t * item = tcp_ssl_get(tcp); + if(item) { + item->on_handshake = arg; + } +} + +void tcp_ssl_err(struct tcp_pcb *tcp, tcp_ssl_error_cb_t arg){ + tcp_ssl_t * item = tcp_ssl_get(tcp); + if(item) { + item->on_error = arg; + } +} + +#endif // ASYNC_TCP_SSL_ENABLED diff --git a/src/tcp_mbedtls.h b/src/tcp_mbedtls.h new file mode 100644 index 00000000..492c70bb --- /dev/null +++ b/src/tcp_mbedtls.h @@ -0,0 +1,51 @@ +#ifndef LWIPR_MBEDTLS_H +#define LWIPR_MBEDTLS_H + +#if ASYNC_TCP_SSL_ENABLED + +#include "mbedtls/platform.h" +#include "mbedtls/net.h" +#include "mbedtls/debug.h" +#include "mbedtls/ssl.h" +#include "mbedtls/entropy.h" +#include "mbedtls/ctr_drbg.h" +#include "mbedtls/error.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define ERR_TCP_SSL_INVALID_SSL -101 +#define ERR_TCP_SSL_INVALID_TCP -102 +#define ERR_TCP_SSL_INVALID_CLIENTFD -103 +#define ERR_TCP_SSL_INVALID_CLIENTFD_DATA -104 +#define ERR_TCP_SSL_INVALID_DATA -105 + +struct tcp_pcb; +struct pbuf; +struct tcp_ssl_pcb; + +typedef void (* tcp_ssl_data_cb_t)(void *arg, struct tcp_pcb *tcp, uint8_t * data, size_t len); +typedef void (* tcp_ssl_handshake_cb_t)(void *arg, struct tcp_pcb *tcp, struct tcp_ssl_pcb* ssl); +typedef void (* tcp_ssl_error_cb_t)(void *arg, struct tcp_pcb *tcp, int8_t error); + +uint8_t tcp_ssl_has_client(); +int tcp_ssl_new_client(struct tcp_pcb *tcp, const char* hostname, const char* root_ca, const size_t root_ca_len); +int tcp_ssl_new_psk_client(struct tcp_pcb *tcp, const char* psk_ident, const char* psk); +int tcp_ssl_write(struct tcp_pcb *tcp, uint8_t *data, size_t len); +int tcp_ssl_read(struct tcp_pcb *tcp, struct pbuf *p); +int tcp_ssl_handshake_step(struct tcp_pcb *tcp); +int tcp_ssl_free(struct tcp_pcb *tcp); +bool tcp_ssl_has(struct tcp_pcb *tcp); +void tcp_ssl_arg(struct tcp_pcb *tcp, void * arg); +void tcp_ssl_data(struct tcp_pcb *tcp, tcp_ssl_data_cb_t arg); +void tcp_ssl_handshake(struct tcp_pcb *tcp, tcp_ssl_handshake_cb_t arg); +void tcp_ssl_err(struct tcp_pcb *tcp, tcp_ssl_error_cb_t arg); + +#ifdef __cplusplus +} +#endif + + +#endif // LWIPR_MBEDTLS_H +#endif // ASYNC_TCP_SSL_ENABLED