diff --git a/src/luasocket/usocket.c b/src/luasocket/usocket.c index 49e618c..55704ba 100644 --- a/src/luasocket/usocket.c +++ b/src/luasocket/usocket.c @@ -82,7 +82,6 @@ int socket_close(void) { \*-------------------------------------------------------------------------*/ void socket_destroy(p_socket ps) { if (*ps != SOCKET_INVALID) { - socket_setblocking(ps); close(*ps); *ps = SOCKET_INVALID; } diff --git a/src/ssl.c b/src/ssl.c index 934bb81..612f5a1 100644 --- a/src/ssl.c +++ b/src/ssl.c @@ -86,7 +86,8 @@ static int meth_destroy(lua_State *L) { p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); if (ssl->state == LSEC_STATE_CONNECTED) { - socket_setblocking(&ssl->sock); + if (!ssl->shutdown) + socket_setblocking(&ssl->sock); SSL_shutdown(ssl->ssl); } if (ssl->sock != SOCKET_INVALID) { @@ -304,6 +305,7 @@ static int meth_create(lua_State *L) return luaL_argerror(L, 1, "invalid context"); } ssl->state = LSEC_STATE_NEW; + ssl->shutdown = 0; SSL_set_fd(ssl->ssl, (int)SOCKET_INVALID); SSL_set_mode(ssl->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); @@ -339,6 +341,56 @@ static int meth_receive(lua_State *L) { return buffer_meth_receive(L, &ssl->buf); } +/** + * SSL shutdown function + */ +static int meth_shutdown(lua_State *L) { + p_ssl ssl = (p_ssl)luaL_checkudata(L, 1, "SSL:Connection"); + int err; + p_timeout tm = timeout_markstart(&ssl->tm); + + ssl->shutdown = 1; + if (ssl->state == LSEC_STATE_CLOSED) { + lua_pushboolean(L, 1); + return 1; + } + + err = SSL_shutdown(ssl->ssl); + switch (err) { + case 0: + lua_pushboolean(L, 0); + lua_pushnil(L); + return 2; + case 1: + lua_pushboolean(L, 1); + lua_pushnil(L); + ssl->state = LSEC_STATE_CLOSED; + return 2; + case -1: + lua_pushboolean(L, 0); + ssl->error = SSL_get_error(ssl->ssl, err); + switch (ssl->error) { + case SSL_ERROR_WANT_READ: + err = socket_waitfd(&ssl->sock, WAITFD_R, tm); + lua_pushstring(L, ssl_ioerror((void *)ssl, err == IO_TIMEOUT ? LSEC_IO_SSL : err)); + break; + case SSL_ERROR_WANT_WRITE: + err = socket_waitfd(&ssl->sock, WAITFD_W, tm); + lua_pushstring(L, ssl_ioerror((void *)ssl, err == IO_TIMEOUT ? LSEC_IO_SSL : err)); + break; + case SSL_ERROR_SYSCALL: + if (ERR_peek_error()) + ssl->error = SSL_ERROR_SSL; + lua_pushstring(L, ssl_ioerror((void *)ssl, LSEC_IO_SSL)); + break; + default: + lua_pushstring(L, ssl_ioerror((void *)ssl, LSEC_IO_SSL)); + break; + } + return 2; + } +} + /** * Get the buffer's statistics. */ @@ -990,6 +1042,7 @@ static int meth_tlsa(lua_State *L) */ static luaL_Reg methods[] = { {"close", meth_close}, + {"shutdown", meth_shutdown}, {"getalpn", meth_getalpn}, {"getfd", meth_getfd}, {"getfinished", meth_getfinished}, diff --git a/src/ssl.h b/src/ssl.h index 2362ffe..d8ed219 100644 --- a/src/ssl.h +++ b/src/ssl.h @@ -32,6 +32,7 @@ typedef struct t_ssl_ { t_timeout tm; SSL *ssl; int state; + int shutdown; int error; } t_ssl; typedef t_ssl* p_ssl;