diff --git a/src/luasocket.c b/src/luasocket.c index c4f51bd..1f9780d 100644 --- a/src/luasocket.c +++ b/src/luasocket.c @@ -9,9 +9,9 @@ * of the IPv4 Socket layer available to Lua scripts. * The Lua interface to TCP/IP follows the BSD TCP/IP API closely, * trying to simplify all tasks involved in setting up a client connection -* and simple server connections. +* and server connections. * The provided IO routines, send and receive, follow the Lua style, being -* very similar to the read and write functions found in that language. +* very similar to the standard Lua read and write functions. * The module implements both a BSD bind and a Winsock2 bind, and has * been tested on several Unix flavors, as well as Windows 98 and NT. \*=========================================================================*/ @@ -19,16 +19,13 @@ /*=========================================================================*\ * Common include files \*=========================================================================*/ -#include #include #include -#include #include -#include +#include #include #include -#include #include "luasocket.h" @@ -161,6 +158,10 @@ static int global_toip(lua_State *L); static int global_tohostname(lua_State *L); static int global_udpsocket(lua_State *L); +#ifndef LUASOCKET_NOGLOBALS +static int global_calltable(lua_State *L); +#endif + /* luasocket table method API functions */ static int table_tcpaccept(lua_State *L); static int table_tcpsend(lua_State *L); @@ -187,6 +188,7 @@ static void tm_markstart(p_sock sock); /* I/O */ static int send_raw(p_sock sock, const char *data, int wanted, int *err); static int receive_raw(lua_State *L, p_sock sock, int wanted); +static int receive_word(lua_State *L, p_sock sock); static int receive_dosline(lua_State *L, p_sock sock); static int receive_unixline(lua_State *L, p_sock sock); static int receive_all(lua_State *L, p_sock sock); @@ -536,6 +538,21 @@ static int table_udpsendto(lua_State *L) } } +/*-------------------------------------------------------------------------*\ +* Global function that calls corresponding table method. +\*-------------------------------------------------------------------------*/ +#ifndef LUASOCKET_NOGLOBALS +int global_calltable(lua_State *L) +{ + p_tags tags = pop_tags(L); + if (lua_tag(L, 1) != tags->table) lua_error(L, "invalid socket object"); + lua_gettable(L, 1); + lua_insert(L, 1); + lua_call(L, lua_gettop(L)-1, LUA_MULTRET); + return lua_gettop(L); +} +#endif + /*-------------------------------------------------------------------------*\ * Waits for a set of sockets until a condition is met or timeout. * Lua Input: {input}, {output} [, timeout] @@ -549,103 +566,105 @@ static int table_udpsendto(lua_State *L) \*-------------------------------------------------------------------------*/ int global_select(lua_State *L) { - p_tags tags = pop_tags(L); - int ms = lua_isnil(L, 3) ? -1 : (int) (luaL_opt_number(L, 3, -1) * 1000); + p_tags tags = pop_tags(L); + int ms = lua_isnil(L, 3) ? -1 : (int) (luaL_opt_number(L, 3, -1) * 1000); fd_set readfds, *prfds = NULL, writefds, *pwfds = NULL; struct timeval tm, *ptm = NULL; - int ret, s, max = -1; - int byfds, canread, canwrite; - /* reset the file descriptor sets */ + int ret; + unsigned max = 0; + SOCKET s; + int byfds, canread, canwrite; + /* reset the file descriptor sets */ FD_ZERO(&readfds); FD_ZERO(&writefds); - /* all sockets, indexed by socket number, for internal use */ - lua_newtable(L); byfds = lua_gettop(L); - /* readable sockets table to be returned */ - lua_newtable(L); canread = lua_gettop(L); - /* writable sockets table to be returned */ - lua_newtable(L); canwrite = lua_gettop(L); - /* get sockets we will test for readability into fd_set */ - if (!lua_isnil(L, 1)) { - lua_pushnil(L); - while (lua_next(L, 1)) { - if (lua_tag(L, -1) == tags->table) { - p_sock sock = get_sock(L, -1, tags, NULL); - lua_pushnumber(L, sock->sock); - lua_pushvalue(L, -2); - lua_settable(L, byfds); - if (sock->sock > max) max = sock->sock; - /* a socket can have unread data in our internal buffer. in - * that case, we only call select to find out which of the - * other sockets can be written to or read from immediately. */ - if (!bf_isempty(sock)) { - ms = 0; - lua_pushnumber(L, lua_getn(L, canread) + 1); - lua_pushvalue(L, -2); - lua_settable(L, canread); - } else { - FD_SET(sock->sock, &readfds); - prfds = &readfds; - } - } - /* get rid of lua_next value and expose index */ - lua_pop(L, 1); - } - } - /* get sockets we will test for writability into fd_set */ - if (!lua_isnil(L, 2)) { - lua_pushnil(L); - while (lua_next(L, 2)) { - if (lua_tag(L, -1) == tags->table) { - p_sock sock = get_sock(L, -1, tags, NULL); - lua_pushnumber(L, sock->sock); - lua_pushvalue(L, -2); - lua_settable(L, byfds); - if (sock->sock > max) max = sock->sock; - FD_SET(sock->sock, &writefds); - pwfds = &writefds; - } - /* get rid of lua_next value and expose index */ - lua_pop(L, 1); - } - } - max++; - /* configure timeout value */ + /* all sockets, indexed by socket number, for internal use */ + lua_newtable(L); byfds = lua_gettop(L); + /* readable sockets table to be returned */ + lua_newtable(L); canread = lua_gettop(L); + /* writable sockets table to be returned */ + lua_newtable(L); canwrite = lua_gettop(L); + /* get sockets we will test for readability into fd_set */ + if (!lua_isnil(L, 1)) { + lua_pushnil(L); + while (lua_next(L, 1)) { + if (lua_tag(L, -1) == tags->table) { + p_sock sock = get_sock(L, -1, tags, NULL); + lua_pushnumber(L, sock->sock); + lua_pushvalue(L, -2); + lua_settable(L, byfds); + if (sock->sock > max) max = sock->sock; + /* a socket can have unread data in our internal buffer. in + * that case, we only call select to find out which of the + * other sockets can be written to or read from immediately. */ + if (!bf_isempty(sock)) { + ms = 0; + lua_pushnumber(L, lua_getn(L, canread) + 1); + lua_pushvalue(L, -2); + lua_settable(L, canread); + } else { + FD_SET(sock->sock, &readfds); + prfds = &readfds; + } + } + /* get rid of lua_next value and expose index */ + lua_pop(L, 1); + } + } + /* get sockets we will test for writability into fd_set */ + if (!lua_isnil(L, 2)) { + lua_pushnil(L); + while (lua_next(L, 2)) { + if (lua_tag(L, -1) == tags->table) { + p_sock sock = get_sock(L, -1, tags, NULL); + lua_pushnumber(L, sock->sock); + lua_pushvalue(L, -2); + lua_settable(L, byfds); + if (sock->sock > max) max = sock->sock; + FD_SET(sock->sock, &writefds); + pwfds = &writefds; + } + /* get rid of lua_next value and expose index */ + lua_pop(L, 1); + } + } + max++; + /* configure timeout value */ if (ms >= 0) { - ptm = &tm; /* ptm == NULL when we don't have timeout */ - /* fill timeval structure */ - tm.tv_sec = ms / 1000; - tm.tv_usec = (ms % 1000) * 1000; - } + ptm = &tm; /* ptm == NULL when we don't have timeout */ + /* fill timeval structure */ + tm.tv_sec = ms / 1000; + tm.tv_usec = (ms % 1000) * 1000; + } /* see if we can read, write or if we timedout */ ret = select(max, prfds, pwfds, NULL, ptm); - /* did we timeout? */ - if (ret <= 0 && ms > 0) { - push_error(L, NET_TIMEOUT); - return 3; - } - /* collect readable sockets */ - if (prfds) { - for (s = 0; s < max; s++) { - if (FD_ISSET(s, prfds)) { - lua_pushnumber(L, lua_getn(L, canread) + 1); - lua_pushnumber(L, s); - lua_gettable(L, byfds); - lua_settable(L, canread); - } - } - } - /* collect writable sockets */ - if (pwfds) { - for (s = 0; s < max; s++) { - if (FD_ISSET(s, pwfds)) { - lua_pushnumber(L, lua_getn(L, canwrite) + 1); - lua_pushnumber(L, s); - lua_gettable(L, byfds); - lua_settable(L, canwrite); - } - } - } - lua_pushnil(L); - return 3; + /* did we timeout? */ + if (ret <= 0 && ms > 0) { + push_error(L, NET_TIMEOUT); + return 3; + } + /* collect readable sockets */ + if (prfds) { + for (s = 0; s < max; s++) { + if (FD_ISSET(s, prfds)) { + lua_pushnumber(L, lua_getn(L, canread) + 1); + lua_pushnumber(L, s); + lua_gettable(L, byfds); + lua_settable(L, canread); + } + } + } + /* collect writable sockets */ + if (pwfds) { + for (s = 0; s < max; s++) { + if (FD_ISSET(s, pwfds)) { + lua_pushnumber(L, lua_getn(L, canwrite) + 1); + lua_pushnumber(L, s); + lua_gettable(L, byfds); + lua_settable(L, canwrite); + } + } + } + lua_pushnil(L); + return 3; } /*-------------------------------------------------------------------------*\ @@ -821,7 +840,7 @@ static int table_udpreceive(lua_State *L) \*-------------------------------------------------------------------------*/ static int table_tcpreceive(lua_State *L) { - static const char *const modenames[] = {"*l", "*lu", "*a", NULL}; + static const char *const modenames[] = {"*l", "*lu", "*a", "*w", NULL}; const char *mode; int err = NET_DONE; int arg; @@ -850,17 +869,13 @@ static int table_tcpreceive(lua_State *L) /* get next pattern */ switch (luaL_findstring(mode, modenames)) { /* DOS line mode */ - case 0: - err = receive_dosline(L, sock); - break; + case 0: err = receive_dosline(L, sock); break; /* Unix line mode */ - case 1: - err = receive_unixline(L, sock); - break; + case 1: err = receive_unixline(L, sock); break; /* until closed mode */ - case 2: - err = receive_all(L, sock); - break; + case 2: err = receive_all(L, sock); break; + /* word */ + case 3: err = receive_word(L, sock); break; /* else it is an error */ default: luaL_arg_check(L, 0, arg, "invalid receive pattern"); @@ -1431,7 +1446,7 @@ static int receive_all(lua_State *L, p_sock sock) \*-------------------------------------------------------------------------*/ static int receive_dosline(lua_State *L, p_sock sock) { - int got = 0; + int got, pos; const unsigned char *buffer = NULL; luaL_Buffer b; luaL_buffinit(L, &b); @@ -1441,26 +1456,21 @@ static int receive_dosline(lua_State *L, p_sock sock) return NET_TIMEOUT; } buffer = bf_receive(sock, &got); - if (got > 0) { - int len = 0, end = 1; - while (len < got) { - if (buffer[len] == '\n') { /* found eol */ - if (len > 0 && buffer[len-1] == '\r') { - end++; len--; - } - luaL_addlstring(&b, buffer, len); - bf_skip(sock, len + end); /* skip '\r\n' in stream */ - luaL_pushresult(&b); - return NET_DONE; - } - len++; - } - luaL_addlstring(&b, buffer, got); - bf_skip(sock, got); - } else { - luaL_pushresult(&b); + if (got <= 0) { + luaL_pushresult(&b); return NET_CLOSED; - } + } + pos = 0; + while (pos < got && buffer[pos] != '\n') { + /* we ignore all \r's */ + if (buffer[pos] != '\r') luaL_putchar(&b, buffer[pos]); + pos++; + } + if (pos < got) { + luaL_pushresult(&b); + bf_skip(sock, pos+1); /* skip '\n' too */ + return NET_DONE; + } else bf_skip(sock, pos); } } @@ -1475,7 +1485,7 @@ static int receive_dosline(lua_State *L, p_sock sock) \*-------------------------------------------------------------------------*/ static int receive_unixline(lua_State *L, p_sock sock) { - int got = 0; + int got, pos; const unsigned char *buffer = NULL; luaL_Buffer b; luaL_buffinit(L, &b); @@ -1485,23 +1495,75 @@ static int receive_unixline(lua_State *L, p_sock sock) return NET_TIMEOUT; } buffer = bf_receive(sock, &got); - if (got > 0) { - int len = 0; - while (len < got) { - if (buffer[len] == '\n') { /* found eol */ - luaL_addlstring(&b, buffer, len); - bf_skip(sock, len + 1); /* skip '\n' in stream */ - luaL_pushresult(&b); - return NET_DONE; - } - len++; - } - luaL_addlstring(&b, buffer, got); - bf_skip(sock, got); - } else { + if (got <= 0) { + luaL_pushresult(&b); + return NET_CLOSED; + } + pos = 0; + while (pos < got && buffer[pos] != '\n') pos++; + luaL_addlstring(&b, buffer, pos); + if (pos < got) { + luaL_pushresult(&b); + bf_skip(sock, pos+1); /* skip '\n' too */ + return NET_DONE; + } else bf_skip(sock, pos); + } +} + +/*-------------------------------------------------------------------------*\ +* Reads a word (maximal sequence of non--white-space characters), skipping +* white-spaces if needed. +* Input +* sock: socket structure being used in operation +* Result +* operation error code. NET_DONE, NET_TIMEOUT or NET_CLOSED +\*-------------------------------------------------------------------------*/ +static int receive_word(lua_State *L, p_sock sock) +{ + int pos, got; + const unsigned char *buffer = NULL; + luaL_Buffer b; + luaL_buffinit(L, &b); + /* skip leading white-spaces */ + for ( ;; ) { + if (bf_isempty(sock) && tm_timedout(sock, TM_RECEIVE)) { + lua_pushstring(L, ""); + return NET_TIMEOUT; + } + buffer = bf_receive(sock, &got); + if (got <= 0) { + lua_pushstring(L, ""); + return NET_CLOSED; + } + pos = 0; + while (pos < got && isspace(buffer[pos])) pos++; + bf_skip(sock, pos); + if (pos < got) { + buffer += pos; + got -= pos; + pos = 0; + break; + } + } + /* capture word */ + for ( ;; ) { + while (pos < got && !isspace(buffer[pos])) pos++; + luaL_addlstring(&b, buffer, pos); + bf_skip(sock, pos); + if (pos < got) { + luaL_pushresult(&b); + return NET_DONE; + } + if (bf_isempty(sock) && tm_timedout(sock, TM_RECEIVE)) { + luaL_pushresult(&b); + return NET_TIMEOUT; + } + buffer = bf_receive(sock, &got); + if (got <= 0) { luaL_pushresult(&b); return NET_CLOSED; } + pos = 0; } } @@ -1514,7 +1576,7 @@ static int receive_unixline(lua_State *L, p_sock sock) \*-------------------------------------------------------------------------*/ void lua_socketlibopen(lua_State *L) { - static struct luaL_reg funcs[] = { + struct luaL_reg funcs[] = { {"bind", global_tcpbind}, {"connect", global_tcpconnect}, {"select", global_select}, @@ -1552,6 +1614,22 @@ void lua_socketlibopen(lua_State *L) lua_pushcfunction(L, global_sleep); lua_setglobal(L, "sleep"); lua_pushcfunction(L, global_time); lua_setglobal(L, "time"); #endif +#ifndef LUASOCKET_NOGLOBALS + { + char *global[] = { + "accept", "close", "getpeername", + "getsockname", "receive", "send", + "receivefrom", "sendto" + }; + unsigned int i; + for (i = 0; i < sizeof(global)/sizeof(char *); i++) { + lua_pushstring(L, global[i]); + lua_pushuserdata(L, tags); + lua_pushcclosure(L, global_calltable, 2); + lua_setglobal(L, global[i]); + } + } +#endif } /*=========================================================================*\ @@ -1583,7 +1661,7 @@ static p_sock push_clienttable(lua_State *L, p_tags tags) if (!sock) return NULL; lua_settag(L, tags->client); lua_settable(L, -3); - sock->sock = -1; + sock->sock = INVALID_SOCKET; sock->is_connected = 0; sock->tm_block = -1; sock->tm_return = -1;