/*-------------------------------------------------------------------------- * LuaSec 0.2 * Copyright (C) 2006-2007 Bruno Silvestre * *--------------------------------------------------------------------------*/ #include #include #include #include #include #include #include "io.h" #include "buffer.h" #include "timeout.h" #include "socket.h" #include "context.h" #include "ssl.h" /** * Map error code into string. */ static const char *ssl_ioerror(void *ctx, int err) { if (err == IO_SSL) { p_ssl ssl = (p_ssl) ctx; switch(ssl->error) { case SSL_ERROR_NONE: return "No error"; case SSL_ERROR_ZERO_RETURN: return "closed"; case SSL_ERROR_WANT_READ: return "wantread"; case SSL_ERROR_WANT_WRITE: return "wantwrite"; case SSL_ERROR_WANT_CONNECT: return "'connect' not completed"; case SSL_ERROR_WANT_ACCEPT: return "'accept' not completed"; case SSL_ERROR_WANT_X509_LOOKUP: return "Waiting for callback"; case SSL_ERROR_SYSCALL: return "System error"; case SSL_ERROR_SSL: return ERR_reason_error_string(ERR_get_error()); default: return "Unknown SSL error"; } } return socket_strerror(err); } /** * Close the connection before the GC collect the object. */ static int meth_destroy(lua_State *L) { p_ssl ssl = (p_ssl) lua_touserdata(L, 1); if (ssl->ssl) { socket_setblocking(&ssl->sock); SSL_shutdown(ssl->ssl); socket_destroy(&ssl->sock); SSL_free(ssl->ssl); ssl->ssl = NULL; } return 0; } /** * Perform the TLS/SSL handshake */ static int handshake(p_ssl ssl) { p_timeout tm = timeout_markstart(&ssl->tm); if (ssl->state == ST_SSL_CLOSED) return IO_CLOSED; for ( ; ; ) { int err = SSL_do_handshake(ssl->ssl); ssl->error = SSL_get_error(ssl->ssl, err); switch(ssl->error) { case SSL_ERROR_NONE: ssl->state = ST_SSL_CONNECTED; return IO_DONE; case SSL_ERROR_WANT_READ: err = socket_waitfd(&ssl->sock, WAITFD_R, tm); if (err == IO_TIMEOUT) return IO_SSL; if (err != IO_DONE) return err; break; case SSL_ERROR_WANT_WRITE: err = socket_waitfd(&ssl->sock, WAITFD_W, tm); if (err == IO_TIMEOUT) return IO_SSL; if (err != IO_DONE) return err; break; case SSL_ERROR_SYSCALL: if (ERR_peek_error()) { ssl->error = SSL_ERROR_SSL; return IO_SSL; } if (err == 0) return IO_CLOSED; return errno; default: return IO_SSL; } } return IO_UNKNOWN; } /** * Send data */ static int ssl_send(void *ctx, const char *data, size_t count, size_t *sent, p_timeout tm) { p_ssl ssl = (p_ssl) ctx; if (ssl->state == ST_SSL_CLOSED) return IO_CLOSED; *sent = 0; for ( ; ; ) { int err = SSL_write(ssl->ssl, data, (int) count); ssl->error = SSL_get_error(ssl->ssl, err); switch(ssl->error) { case SSL_ERROR_NONE: *sent = err; return IO_DONE; case SSL_ERROR_WANT_READ: err = socket_waitfd(&ssl->sock, WAITFD_R, tm); if (err == IO_TIMEOUT) return IO_SSL; if (err != IO_DONE) return err; break; case SSL_ERROR_WANT_WRITE: err = socket_waitfd(&ssl->sock, WAITFD_W, tm); if (err == IO_TIMEOUT) return IO_SSL; if (err != IO_DONE) return err; break; case SSL_ERROR_SYSCALL: if (ERR_peek_error()) { ssl->error = SSL_ERROR_SSL; return IO_SSL; } if (err == 0) return IO_CLOSED; return errno; default: return IO_SSL; } } return IO_UNKNOWN; } /** * Receive data */ static int ssl_recv(void *ctx, char *data, size_t count, size_t *got, p_timeout tm) { p_ssl ssl = (p_ssl) ctx; if (ssl->state == ST_SSL_CLOSED) return IO_CLOSED; *got = 0; for ( ; ; ) { int err = SSL_read(ssl->ssl, data, (int) count); ssl->error = SSL_get_error(ssl->ssl, err); switch(ssl->error) { case SSL_ERROR_NONE: *got = err; return IO_DONE; case SSL_ERROR_WANT_READ: err = socket_waitfd(&ssl->sock, WAITFD_R, tm); if (err == IO_TIMEOUT) return IO_SSL; if (err != IO_DONE) return err; break; case SSL_ERROR_WANT_WRITE: err = socket_waitfd(&ssl->sock, WAITFD_W, tm); if (err == IO_TIMEOUT) return IO_SSL; if (err != IO_DONE) return err; break; case SSL_ERROR_SYSCALL: if (ERR_peek_error()) { ssl->error = SSL_ERROR_SSL; return IO_SSL; } if (err == 0) return IO_CLOSED; return errno; default: return IO_SSL; } } return IO_UNKNOWN; } /** * Create a new TLS/SSL object and mark it as new. */ static int meth_create(lua_State *L) { p_ssl ssl; int mode = ctx_getmode(L, 1); SSL_CTX *ctx = ctx_getcontext(L, 1); if (mode == MD_CTX_INVALID) { lua_pushnil(L); lua_pushstring(L, "invalid mode"); return 2; } ssl = (p_ssl) lua_newuserdata(L, sizeof(t_ssl)); if (!ssl) { lua_pushnil(L); lua_pushstring(L, "error creating SSL object"); return 2; } ssl->ssl = SSL_new(ctx); if (!ssl->ssl) { lua_pushnil(L); lua_pushstring(L, "error creating SSL object"); return 2;; } ssl->state = ST_SSL_NEW; SSL_set_fd(ssl->ssl, (int) SOCKET_INVALID); SSL_set_mode(ssl->ssl, SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); if (mode == MD_CTX_SERVER) SSL_set_accept_state(ssl->ssl); else SSL_set_connect_state(ssl->ssl); io_init(&ssl->io, (p_send) ssl_send, (p_recv) ssl_recv, (p_error) ssl_ioerror, ssl); timeout_init(&ssl->tm, -1, -1); buffer_init(&ssl->buf, &ssl->io, &ssl->tm); luaL_getmetatable(L, "SSL:Connection"); lua_setmetatable(L, -2); return 1; } /** * Buffer send function */ static int meth_send(lua_State *L) { p_ssl ssl = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection"); return buffer_meth_send(L, &ssl->buf); } /** * Buffer receive function */ static int meth_receive(lua_State *L) { p_ssl ssl = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection"); return buffer_meth_receive(L, &ssl->buf); } /** * Select support methods */ static int meth_getfd(lua_State *L) { p_ssl ssl = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection"); lua_pushnumber(L, ssl->sock); return 1; } /** * Set the TLS/SSL file descriptor. * This is done *before* the handshake. */ static int meth_setfd(lua_State *L) { p_ssl ssl = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection"); if (ssl->state != ST_SSL_NEW) luaL_argerror(L, 1, "invalid SSL object state"); ssl->sock = luaL_checkint(L, 2); socket_setnonblocking(&ssl->sock); SSL_set_fd(ssl->ssl, (int)ssl->sock); return 0; } /** * Lua handshake function. */ static int meth_handshake(lua_State *L) { p_ssl ssl = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection"); int err = handshake(ssl); if (err == IO_DONE) { lua_pushboolean(L, 1); return 1; } lua_pushboolean(L, 0); lua_pushstring(L, ssl_ioerror((void*)ssl, err)); return 2; } /** * Close the connection. */ static int meth_close(lua_State *L) { p_ssl ssl = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection"); meth_destroy(L); ssl->state = ST_SSL_CLOSED; return 0; } /** * Set timeout. */ static int meth_settimeout(lua_State *L) { p_ssl ssl = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection"); return timeout_meth_settimeout(L, &ssl->tm); } /** * Check if there is data in the buffer. */ static int meth_dirty(lua_State *L) { int res = 0; p_ssl ssl = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection"); if (ssl->state != ST_SSL_CLOSED) res = !buffer_isempty(&ssl->buf) || SSL_pending(ssl->ssl); lua_pushboolean(L, res); return 1; } /** * Return the state information about the SSL object. */ static int meth_want(lua_State *L) { p_ssl ssl = (p_ssl) luaL_checkudata(L, 1, "SSL:Connection"); int code = (ssl->state == ST_SSL_CLOSED) ? SSL_NOTHING : SSL_want(ssl->ssl); switch(code) { case SSL_NOTHING: lua_pushstring(L, "nothing"); break; case SSL_READING: lua_pushstring(L, "read"); break; case SSL_WRITING: lua_pushstring(L, "write"); break; case SSL_X509_LOOKUP: lua_pushstring(L, "x509lookup"); break; } return 1; } /*---------------------------------------------------------------------------*/ /** * SSL metamethods */ static luaL_Reg meta[] = { {"close", meth_close}, {"getfd", meth_getfd}, {"dirty", meth_dirty}, {"dohandshake", meth_handshake}, {"receive", meth_receive}, {"send", meth_send}, {"settimeout", meth_settimeout}, {"want", meth_want}, {NULL, NULL} }; /** * SSL functions */ static luaL_Reg funcs[] = { {"create", meth_create}, {"setfd", meth_setfd}, {NULL, NULL} }; /** * Initialize modules */ LUASEC_API int luaopen_ssl_core(lua_State *L) { /* Initialize SSL */ if (!SSL_library_init()) { lua_pushstring(L, "unable to initialize SSL library"); lua_error(L); } SSL_load_error_strings(); /* Initialize internal library */ socket_open(); /* Registre the functions and tables */ luaL_newmetatable(L, "SSL:Connection"); lua_newtable(L); luaL_register(L, NULL, meta); lua_setfield(L, -2, "__index"); lua_pushcfunction(L, meth_destroy); lua_setfield(L, -2, "__gc"); luaL_register(L, "ssl.core", funcs); lua_pushnumber(L, SOCKET_INVALID); lua_setfield(L, -2, "invalidfd"); return 1; }