From bccb1f74f06a0667d53d8db52fdc81dc7f3a5c5c Mon Sep 17 00:00:00 2001 From: Bruno Silvestre Date: Thu, 8 May 2025 15:50:30 -0300 Subject: [PATCH] Fix IO_DONE error, remove shutdown field --- src/ssl.c | 114 ++++++++++++++++++++++++++++-------------------------- src/ssl.h | 1 - 2 files changed, 59 insertions(+), 56 deletions(-) diff --git a/src/ssl.c b/src/ssl.c index 1de15be..f104ac6 100644 --- a/src/ssl.c +++ b/src/ssl.c @@ -11,6 +11,9 @@ #if defined(WIN32) #include +#define LSEC_ERR_INPROGRESS WSAEINPROGRESS +#else +#define LSEC_ERR_INPROGRESS EINPROGRESS #endif #include @@ -86,8 +89,7 @@ static int meth_destroy(lua_State *L) { p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); if (ssl->state == LSEC_STATE_CONNECTED) { - if (!ssl->shutdown) - socket_setblocking(&ssl->sock); + socket_setblocking(&ssl->sock); SSL_shutdown(ssl->ssl); } if (ssl->sock != SOCKET_INVALID) { @@ -153,6 +155,46 @@ static int handshake(p_ssl ssl) return IO_UNKNOWN; } +/** + * Perform the TLS/SSL shutdown + */ +static int low_shutdown(p_ssl ssl) +{ + int err; + p_timeout tm = timeout_markstart(&ssl->tm); + if (ssl->state == LSEC_STATE_CLOSED) + return IO_CLOSED; + for ( ; ; ) { + ERR_clear_error(); + err = SSL_shutdown(ssl->ssl); + if (err == 0) return LSEC_ERR_INPROGRESS; + if (err == 1) return IO_DONE; + ssl->error = SSL_get_error(ssl->ssl, err); + switch (ssl->error) { + case SSL_ERROR_WANT_READ: + err = socket_waitfd(&ssl->sock, WAITFD_R, tm); + if (err == IO_TIMEOUT) return LSEC_IO_SSL; + if (err != IO_DONE) return err; + break; + case SSL_ERROR_WANT_WRITE: + err = socket_waitfd(&ssl->sock, WAITFD_W, tm); + if (err == IO_TIMEOUT) return LSEC_IO_SSL; + if (err != IO_DONE) return err; + break; + case SSL_ERROR_SYSCALL: + if (ERR_peek_error()) { + ssl->error = SSL_ERROR_SSL; + return LSEC_IO_SSL; + } + if (err == 0) return IO_CLOSED; + return lsec_socket_error(); + default: + return LSEC_IO_SSL; + } + } + return IO_UNKNOWN; +} + /** * Send data */ @@ -305,7 +347,6 @@ static int meth_create(lua_State *L) return luaL_argerror(L, 1, "invalid context"); } ssl->state = LSEC_STATE_NEW; - ssl->shutdown = 0; SSL_set_fd(ssl->ssl, (int)SOCKET_INVALID); SSL_set_mode(ssl->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); @@ -341,58 +382,6 @@ static int meth_receive(lua_State *L) { return buffer_meth_receive(L, &ssl->buf); } -/** - * SSL shutdown function - */ -static int meth_shutdown(lua_State *L) { - p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); - int err; - p_timeout tm = timeout_markstart(&ssl->tm); - - ssl->shutdown = 1; - if (ssl->state == LSEC_STATE_CLOSED) { - lua_pushboolean(L, 1); - return 1; - } - - err = SSL_shutdown(ssl->ssl); - switch (err) { - case 0: - lua_pushboolean(L, 0); - lua_pushnil(L); - return 2; - case 1: - lua_pushboolean(L, 1); - lua_pushnil(L); - ssl->state = LSEC_STATE_CLOSED; - return 2; - default: - lua_pushboolean(L, 0); - ssl->error = SSL_get_error(ssl->ssl, err); - switch (ssl->error) { - case SSL_ERROR_WANT_READ: - err = socket_waitfd(&ssl->sock, WAITFD_R, tm); - lua_pushstring(L, ssl_ioerror((void *)ssl, err == IO_TIMEOUT ? LSEC_IO_SSL : err)); - break; - case SSL_ERROR_WANT_WRITE: - err = socket_waitfd(&ssl->sock, WAITFD_W, tm); - lua_pushstring(L, ssl_ioerror((void *)ssl, err == IO_TIMEOUT ? LSEC_IO_SSL : err)); - break; - case SSL_ERROR_SYSCALL: - if (ERR_peek_error()) - ssl->error = SSL_ERROR_SSL; - lua_pushstring(L, ssl_ioerror((void *)ssl, LSEC_IO_SSL)); - break; - default: - lua_pushstring(L, ssl_ioerror((void *)ssl, LSEC_IO_SSL)); - break; - } - return 2; - } - // unreachable - return 0; -} - /** * Get the buffer's statistics. */ @@ -461,6 +450,21 @@ static int meth_handshake(lua_State *L) return 2; } +/** + * Lua shutdown function. + */ +static int meth_shutdown(lua_State *L) { + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + int err = low_shutdown(ssl); + if (err == IO_DONE) { + lua_pushboolean(L, 1); + return 1; + } + lua_pushboolean(L, 0); + lua_pushstring(L, (err == LSEC_ERR_INPROGRESS) ? "inprogress" : ssl_ioerror((void*)ssl, err)); + return 2; +} + /** * Close the connection. */ diff --git a/src/ssl.h b/src/ssl.h index d8ed219..2362ffe 100644 --- a/src/ssl.h +++ b/src/ssl.h @@ -32,7 +32,6 @@ typedef struct t_ssl_ { t_timeout tm; SSL *ssl; int state; - int shutdown; int error; } t_ssl; typedef t_ssl* p_ssl;