diff --git a/NEW b/NEW index 749641a..879f12c 100644 --- a/NEW +++ b/NEW @@ -1,5 +1,20 @@ -Socket structures are independent -UDPBUFFERSIZE is now internal -Better treatment of closed connections: test!!! -HTTP post now deals with 1xx codes -connect, bind etc only try first address returned by resolver +All functions provided by the library are in the namespace "socket". +Functions such as send/receive/timeout/close etc do not exist in the +namespace. They are now only available as methods of the appropriate +objects. + +Object has been changed to become more uniform. First create an object for +a given domain/family and protocol. Then connect or bind if needed. Then +use IO functions. + +All functions return a non-nil value as first return value if successful. +All functions return nil followed by error message in case of error. +WARNING: The send function was affected. + +Better error messages and parameter checking. + +UDP connected udp sockets can break association with peer by calling +setpeername with address "*". + +socket.sleep and socket.time are now part of the library and are +supported. diff --git a/etc/tftp.lua b/etc/tftp.lua new file mode 100644 index 0000000..a0db68e --- /dev/null +++ b/etc/tftp.lua @@ -0,0 +1,131 @@ +----------------------------------------------------------------------------- +-- TFTP support for the Lua language +-- LuaSocket 1.5 toolkit. +-- Author: Diego Nehab +-- Conforming to: RFC 783, LTN7 +-- RCS ID: $Id$ +----------------------------------------------------------------------------- + +local Public, Private = {}, {} +local socket = _G[LUASOCKET_LIBNAME] -- get LuaSocket namespace +socket.tftp = Public -- create tftp sub namespace + +----------------------------------------------------------------------------- +-- Program constants +----------------------------------------------------------------------------- +local char = string.char +local byte = string.byte + +Public.PORT = 69 +Private.OP_RRQ = 1 +Private.OP_WRQ = 2 +Private.OP_DATA = 3 +Private.OP_ACK = 4 +Private.OP_ERROR = 5 +Private.OP_INV = {"RRQ", "WRQ", "DATA", "ACK", "ERROR"} + +----------------------------------------------------------------------------- +-- Packet creation functions +----------------------------------------------------------------------------- +function Private.RRQ(source, mode) + return char(0, Private.OP_RRQ) .. source .. char(0) .. mode .. char(0) +end + +function Private.WRQ(source, mode) + return char(0, Private.OP_RRQ) .. source .. char(0) .. mode .. char(0) +end + +function Private.ACK(block) + local low, high + low = math.mod(block, 256) + high = (block - low)/256 + return char(0, Private.OP_ACK, high, low) +end + +function Private.get_OP(dgram) + local op = byte(dgram, 1)*256 + byte(dgram, 2) + return op +end + +----------------------------------------------------------------------------- +-- Packet analysis functions +----------------------------------------------------------------------------- +function Private.split_DATA(dgram) + local block = byte(dgram, 3)*256 + byte(dgram, 4) + local data = string.sub(dgram, 5) + return block, data +end + +function Private.get_ERROR(dgram) + local code = byte(dgram, 3)*256 + byte(dgram, 4) + local msg + _,_, msg = string.find(dgram, "(.*)\000", 5) + return string.format("error code %d: %s", code, msg) +end + +----------------------------------------------------------------------------- +-- Downloads and returns a file pointed to by url +----------------------------------------------------------------------------- +function Public.get(url) + local parsed = socket.url.parse(url, { + host = "", + port = Public.PORT, + path ="/", + scheme = "tftp" + }) + if parsed.scheme ~= "tftp" then + return nil, string.format("unknown scheme '%s'", parsed.scheme) + end + local retries, dgram, sent, datahost, dataport, code + local cat = socket.concat.create() + local last = 0 + local udp, err = socket.udp() + if not udp then return nil, err end + -- convert from name to ip if needed + parsed.host = socket.toip(parsed.host) + udp:timeout(1) + -- first packet gives data host/port to be used for data transfers + retries = 0 + repeat + sent, err = udp:sendto(Private.RRQ(parsed.path, "octet"), + parsed.host, parsed.port) + if err then return nil, err end + dgram, datahost, dataport = udp:receivefrom() + retries = retries + 1 + until dgram or datahost ~= "timeout" or retries > 5 + if not dgram then return nil, datahost end + -- associate socket with data host/port + udp:setpeername(datahost, dataport) + -- process all data packets + while 1 do + -- decode packet + code = Private.get_OP(dgram) + if code == Private.OP_ERROR then + return nil, Private.get_ERROR(dgram) + end + if code ~= Private.OP_DATA then + return nil, "unhandled opcode " .. code + end + -- get data packet parts + local block, data = Private.split_DATA(dgram) + -- if not repeated, write + if block == last+1 then + cat:addstring(data) + last = block + end + -- last packet brings less than 512 bytes of data + if string.len(data) < 512 then + sent, err = udp:send(Private.ACK(block)) + return cat:getresult() + end + -- get the next packet + retries = 0 + repeat + sent, err = udp:send(Private.ACK(last)) + if err then return err end + dgram, err = udp:receive() + retries = retries + 1 + until dgram or err ~= "timeout" or retries > 5 + if not dgram then return err end + end +end diff --git a/samples/daytimeclnt.lua b/samples/daytimeclnt.lua index 000dfd5..4debc81 100644 --- a/samples/daytimeclnt.lua +++ b/samples/daytimeclnt.lua @@ -7,7 +7,8 @@ end host = socket.toip(host) udp = socket.udp() print("Using host '" ..host.. "' and port " ..port.. "...") -err = udp:sendto("anything", host, port) +udp:setpeername(host, port) +sent, err = udp:send("anything") if err then print(err) exit() end dgram, err = udp:receive() if not dgram then print(err) exit() end diff --git a/samples/talker.lua b/samples/talker.lua index d66cf66..688824f 100644 --- a/samples/talker.lua +++ b/samples/talker.lua @@ -5,18 +5,18 @@ if arg then port = arg[2] or port end print("Attempting connection to host '" ..host.. "' and port " ..port.. "...") -c, e = connect(host, port) +c, e = socket.connect(host, port) if not c then print(e) - exit() + os.exit() end print("Connected! Please type stuff (empty line to stop):") -l = read() +l = io.read() while l and l ~= "" and not e do - e = c:send(l, "\n") + t, e = c:send(l, "\n") if e then print(e) - exit() + os.exit() end - l = read() + l = io.read() end diff --git a/src/auxiliar.c b/src/auxiliar.c new file mode 100644 index 0000000..5e5ba1a --- /dev/null +++ b/src/auxiliar.c @@ -0,0 +1,113 @@ +/*=========================================================================*\ +* Auxiliar routines for class hierarchy manipulation +* +* RCS ID: $Id$ +\*=========================================================================*/ +#include "aux.h" + +/*=========================================================================*\ +* Internal function prototypes +\*=========================================================================*/ +static void *aux_getgroupudata(lua_State *L, const char *group, int objidx); + +/*=========================================================================*\ +* Exported functions +\*=========================================================================*/ +/*-------------------------------------------------------------------------*\ +* Creates a new class. A class has methods given by the func array and the +* field 'class' tells the object class. The table 'group' list the class +* groups the object belongs to. +\*-------------------------------------------------------------------------*/ +void aux_newclass(lua_State *L, const char *name, luaL_reg *func) +{ + luaL_newmetatable(L, name); + lua_pushstring(L, "__index"); + lua_newtable(L); + luaL_openlib(L, NULL, func, 0); + lua_pushstring(L, "class"); + lua_pushstring(L, name); + lua_settable(L, -3); + lua_settable(L, -3); + lua_pushstring(L, "group"); + lua_newtable(L); + lua_settable(L, -3); + lua_pop(L, 1); +} + +/*-------------------------------------------------------------------------*\ +* Add group to object list of groups. +\*-------------------------------------------------------------------------*/ +void aux_add2group(lua_State *L, const char *name, const char *group) +{ + luaL_getmetatable(L, name); + lua_pushstring(L, "group"); + lua_gettable(L, -2); + lua_pushstring(L, group); + lua_pushnumber(L, 1); + lua_settable(L, -3); + lua_pop(L, 2); +} + +/*-------------------------------------------------------------------------*\ +* Get a userdata making sure the object belongs to a given class. +\*-------------------------------------------------------------------------*/ +void *aux_checkclass(lua_State *L, const char *name, int objidx) +{ + void *data = luaL_checkudata(L, objidx, name); + if (!data) { + char msg[45]; + sprintf(msg, "%.35s expected", name); + luaL_argerror(L, objidx, msg); + } + return data; +} + +/*-------------------------------------------------------------------------*\ +* Get a userdata making sure the object belongs to a given group. +\*-------------------------------------------------------------------------*/ +void *aux_checkgroup(lua_State *L, const char *group, int objidx) +{ + void *data = aux_getgroupudata(L, group, objidx); + if (!data) { + char msg[45]; + sprintf(msg, "%.35s expected", group); + luaL_argerror(L, objidx, msg); + } + return data; +} + +/*-------------------------------------------------------------------------*\ +* Set object class. +\*-------------------------------------------------------------------------*/ +void aux_setclass(lua_State *L, const char *name, int objidx) +{ + luaL_getmetatable(L, name); + if (objidx < 0) objidx--; + lua_setmetatable(L, objidx); +} + +/*=========================================================================*\ +* Internal functions +\*=========================================================================*/ +/*-------------------------------------------------------------------------*\ +* Get a userdata if object belongs to a given group. +\*-------------------------------------------------------------------------*/ +static void *aux_getgroupudata(lua_State *L, const char *group, int objidx) +{ + if (!lua_getmetatable(L, objidx)) return NULL; + lua_pushstring(L, "group"); + lua_gettable(L, -2); + if (lua_isnil(L, -1)) { + lua_pop(L, 2); + return NULL; + } + lua_pushstring(L, group); + lua_gettable(L, -2); + if (lua_isnil(L, -1)) { + lua_pop(L, 3); + return NULL; + } + lua_pop(L, 3); + return lua_touserdata(L, objidx); +} + diff --git a/src/auxiliar.h b/src/auxiliar.h new file mode 100644 index 0000000..2681a84 --- /dev/null +++ b/src/auxiliar.h @@ -0,0 +1,26 @@ +/*=========================================================================*\ +* Auxiliar routines for class hierarchy manipulation +* +* RCS ID: $Id$ +\*=========================================================================*/ +#ifndef AUX_H +#define AUX_H + +#include +#include + +void aux_newclass(lua_State *L, const char *name, luaL_reg *func); +void aux_add2group(lua_State *L, const char *name, const char *group); +void *aux_checkclass(lua_State *L, const char *name, int objidx); +void *aux_checkgroup(lua_State *L, const char *group, int objidx); +void aux_setclass(lua_State *L, const char *name, int objidx); + +/* min and max macros */ +#ifndef MIN +#define MIN(x, y) ((x) < (y) ? x : y) +#endif +#ifndef MAX +#define MAX(x, y) ((x) > (y) ? x : y) +#endif + +#endif diff --git a/src/buffer.c b/src/buffer.c index 73df8b3..c5ef66c 100644 --- a/src/buffer.c +++ b/src/buffer.c @@ -1,28 +1,24 @@ /*=========================================================================*\ * Buffered input/output routines -* Lua methods: -* send: unbuffered send using C base_send -* receive: buffered read using C base_receive * * RCS ID: $Id$ \*=========================================================================*/ #include #include -#include "lsbuf.h" +#include "error.h" +#include "aux.h" +#include "buf.h" /*=========================================================================*\ -* Internal function prototypes. +* Internal function prototypes \*=========================================================================*/ -static int sendraw(lua_State *L, p_buf buf, cchar *data, size_t len, - size_t *done); static int recvraw(lua_State *L, p_buf buf, size_t wanted); -static int recvdosline(lua_State *L, p_buf buf); -static int recvunixline(lua_State *L, p_buf buf); +static int recvline(lua_State *L, p_buf buf); static int recvall(lua_State *L, p_buf buf); - -static int buf_contents(lua_State *L, p_buf buf, cchar **data, size_t *len); -static void buf_skip(lua_State *L, p_buf buf, size_t len); +static int buf_get(p_buf buf, const char **data, size_t *count); +static void buf_skip(p_buf buf, size_t count); +static int sendraw(p_buf buf, const char *data, size_t count, size_t *sent); /*=========================================================================*\ * Exported functions @@ -37,98 +33,69 @@ void buf_open(lua_State *L) /*-------------------------------------------------------------------------*\ * Initializes C structure -* Input -* buf: buffer structure to initialize -* base: socket object to associate with buffer structure \*-------------------------------------------------------------------------*/ -void buf_init(lua_State *L, p_buf buf, p_base base) +void buf_init(p_buf buf, p_io io, p_tm tm) { - (void) L; - buf->buf_first = buf->buf_last = 0; - buf->buf_base = base; + buf->first = buf->last = 0; + buf->io = io; + buf->tm = tm; } /*-------------------------------------------------------------------------*\ * Send data through buffered object -* Input -* buf: buffer structure to be used -* Lua Input: self, a_1 [, a_2, a_3 ... a_n] -* self: socket object -* a_i: strings to be sent. -* Lua Returns -* On success: nil, followed by the total number of bytes sent -* On error: error message \*-------------------------------------------------------------------------*/ -int buf_send(lua_State *L, p_buf buf) +int buf_meth_send(lua_State *L, p_buf buf) { int top = lua_gettop(L); size_t total = 0; - int err = PRIV_DONE; - int arg; - p_base base = buf->buf_base; - tm_markstart(&base->base_tm); + int arg, err = IO_DONE; + p_tm tm = buf->tm; + tm_markstart(tm); for (arg = 2; arg <= top; arg++) { /* first arg is socket object */ - size_t done, len; - cchar *data = luaL_optlstring(L, arg, NULL, &len); - if (!data || err != PRIV_DONE) break; - err = sendraw(L, buf, data, len, &done); - total += done; + size_t sent, count; + const char *data = luaL_optlstring(L, arg, NULL, &count); + if (!data || err != IO_DONE) break; + err = sendraw(buf, data, count, &sent); + total += sent; } - priv_pusherror(L, err); lua_pushnumber(L, total); + error_push(L, err); #ifdef LUASOCKET_DEBUG /* push time elapsed during operation as the last return value */ - lua_pushnumber(L, tm_getelapsed(&base->base_tm)/1000.0); + lua_pushnumber(L, (tm_gettime() - tm_getstart(tm))/1000.0); #endif return lua_gettop(L) - top; } /*-------------------------------------------------------------------------*\ * Receive data from a buffered object -* Input -* buf: buffer structure to be used -* Lua Input: self [pat_1, pat_2 ... pat_n] -* self: socket object -* pat_i: may be one of the following -* "*l": reads a text line, defined as a string of caracters terminates -* by a LF character, preceded or not by a CR character. This is -* the default pattern -* "*lu": reads a text line, terminanted by a CR character only. (Unix mode) -* "*a": reads until connection closed -* number: reads 'number' characters from the socket object -* Lua Returns -* On success: one string for each pattern -* On error: all strings for which there was no error, followed by one -* nil value for the remaining strings, followed by an error code \*-------------------------------------------------------------------------*/ -int buf_receive(lua_State *L, p_buf buf) +int buf_meth_receive(lua_State *L, p_buf buf) { int top = lua_gettop(L); - int arg, err = PRIV_DONE; - p_base base = buf->buf_base; - tm_markstart(&base->base_tm); + int arg, err = IO_DONE; + p_tm tm = buf->tm; + tm_markstart(tm); /* push default pattern if need be */ if (top < 2) { lua_pushstring(L, "*l"); top++; } - /* make sure we have enough stack space */ + /* make sure we have enough stack space for all returns */ luaL_checkstack(L, top+LUA_MINSTACK, "too many arguments"); /* receive all patterns */ - for (arg = 2; arg <= top && err == PRIV_DONE; arg++) { + for (arg = 2; arg <= top && err == IO_DONE; arg++) { if (!lua_isnumber(L, arg)) { - static cchar *patternnames[] = {"*l", "*lu", "*a", "*w", NULL}; - cchar *pattern = luaL_optstring(L, arg, NULL); + static const char *patternnames[] = {"*l", "*a", NULL}; + const char *pattern = lua_isnil(L, arg) ? + "*l" : luaL_checkstring(L, arg); /* get next pattern */ switch (luaL_findstring(pattern, patternnames)) { - case 0: /* DOS line pattern */ - err = recvdosline(L, buf); break; - case 1: /* Unix line pattern */ - err = recvunixline(L, buf); break; - case 2: /* Until closed pattern */ - err = recvall(L, buf); break; - case 3: /* Word pattern */ - luaL_argcheck(L, 0, arg, "word patterns are deprecated"); + case 0: /* line pattern */ + err = recvline(L, buf); break; + case 1: /* until closed pattern */ + err = recvall(L, buf); + if (err == IO_CLOSED) err = IO_DONE; break; default: /* else it is an error */ luaL_argcheck(L, 0, arg, "invalid receive pattern"); @@ -140,25 +107,20 @@ int buf_receive(lua_State *L, p_buf buf) /* push nil for each pattern after an error */ for ( ; arg <= top; arg++) lua_pushnil(L); /* last return is an error code */ - priv_pusherror(L, err); + error_push(L, err); #ifdef LUASOCKET_DEBUG /* push time elapsed during operation as the last return value */ - lua_pushnumber(L, tm_getelapsed(&base->base_tm)/1000.0); + lua_pushnumber(L, (tm_gettime() - tm_getstart(tm))/1000.0); #endif return lua_gettop(L) - top; } /*-------------------------------------------------------------------------*\ * Determines if there is any data in the read buffer -* Input -* buf: buffer structure to be used -* Returns -* 1 if empty, 0 if there is data \*-------------------------------------------------------------------------*/ -int buf_isempty(lua_State *L, p_buf buf) +int buf_isempty(p_buf buf) { - (void) L; - return buf->buf_first >= buf->buf_last; + return buf->first >= buf->last; } /*=========================================================================*\ @@ -166,24 +128,16 @@ int buf_isempty(lua_State *L, p_buf buf) \*=========================================================================*/ /*-------------------------------------------------------------------------*\ * Sends a raw block of data through a buffered object. -* Input -* buf: buffer structure to be used -* data: data to be sent -* len: number of bytes to send -* Output -* sent: number of bytes sent -* Returns -* operation error code. \*-------------------------------------------------------------------------*/ -static int sendraw(lua_State *L, p_buf buf, cchar *data, size_t len, - size_t *sent) +static int sendraw(p_buf buf, const char *data, size_t count, size_t *sent) { - p_base base = buf->buf_base; + p_io io = buf->io; + p_tm tm = buf->tm; size_t total = 0; - int err = PRIV_DONE; - while (total < len && err == PRIV_DONE) { + int err = IO_DONE; + while (total < count && err == IO_DONE) { size_t done; - err = base->base_send(L, base, data + total, len - total, &done); + err = io->send(io->ctx, data+total, count-total, &done, tm_get(tm)); total += done; } *sent = total; @@ -192,25 +146,21 @@ static int sendraw(lua_State *L, p_buf buf, cchar *data, size_t len, /*-------------------------------------------------------------------------*\ * Reads a raw block of data from a buffered object. -* Input -* buf: buffer structure -* wanted: number of bytes to be read -* Returns -* operation error code. \*-------------------------------------------------------------------------*/ -static int recvraw(lua_State *L, p_buf buf, size_t wanted) +static +int recvraw(lua_State *L, p_buf buf, size_t wanted) { - int err = PRIV_DONE; + int err = IO_DONE; size_t total = 0; luaL_Buffer b; luaL_buffinit(L, &b); - while (total < wanted && err == PRIV_DONE) { - size_t len; cchar *data; - err = buf_contents(L, buf, &data, &len); - len = MIN(len, wanted - total); - luaL_addlstring(&b, data, len); - buf_skip(L, buf, len); - total += len; + while (total < wanted && err == IO_DONE) { + size_t count; const char *data; + err = buf_get(buf, &data, &count); + count = MIN(count, wanted - total); + luaL_addlstring(&b, data, count); + buf_skip(buf, count); + total += count; } luaL_pushresult(&b); return err; @@ -218,21 +168,18 @@ static int recvraw(lua_State *L, p_buf buf, size_t wanted) /*-------------------------------------------------------------------------*\ * Reads everything until the connection is closed -* Input -* buf: buffer structure -* Result -* operation error code. \*-------------------------------------------------------------------------*/ -static int recvall(lua_State *L, p_buf buf) +static +int recvall(lua_State *L, p_buf buf) { - int err = PRIV_DONE; + int err = IO_DONE; luaL_Buffer b; luaL_buffinit(L, &b); - while (err == PRIV_DONE) { - cchar *data; size_t len; - err = buf_contents(L, buf, &data, &len); - luaL_addlstring(&b, data, len); - buf_skip(L, buf, len); + while (err == IO_DONE) { + const char *data; size_t count; + err = buf_get(buf, &data, &count); + luaL_addlstring(&b, data, count); + buf_skip(buf, count); } luaL_pushresult(&b); return err; @@ -241,61 +188,27 @@ static int recvall(lua_State *L, p_buf buf) /*-------------------------------------------------------------------------*\ * Reads a line terminated by a CR LF pair or just by a LF. The CR and LF * are not returned by the function and are discarded from the buffer. -* Input -* buf: buffer structure -* Result -* operation error code. PRIV_DONE, PRIV_TIMEOUT or PRIV_CLOSED \*-------------------------------------------------------------------------*/ -static int recvdosline(lua_State *L, p_buf buf) +static +int recvline(lua_State *L, p_buf buf) { int err = 0; luaL_Buffer b; luaL_buffinit(L, &b); - while (err == PRIV_DONE) { - size_t len, pos; cchar *data; - err = buf_contents(L, buf, &data, &len); + while (err == IO_DONE) { + size_t count, pos; const char *data; + err = buf_get(buf, &data, &count); pos = 0; - while (pos < len && data[pos] != '\n') { + while (pos < count && data[pos] != '\n') { /* we ignore all \r's */ if (data[pos] != '\r') luaL_putchar(&b, data[pos]); pos++; } - if (pos < len) { /* found '\n' */ - buf_skip(L, buf, pos+1); /* skip '\n' too */ + if (pos < count) { /* found '\n' */ + buf_skip(buf, pos+1); /* skip '\n' too */ break; /* we are done */ } else /* reached the end of the buffer */ - buf_skip(L, buf, pos); - } - luaL_pushresult(&b); - return err; -} - -/*-------------------------------------------------------------------------*\ -* Reads a line terminated by a LF character, which is not returned by -* the function, and is skipped in the buffer. -* Input -* buf: buffer structure -* Returns -* operation error code. PRIV_DONE, PRIV_TIMEOUT or PRIV_CLOSED -\*-------------------------------------------------------------------------*/ -static int recvunixline(lua_State *L, p_buf buf) -{ - int err = PRIV_DONE; - luaL_Buffer b; - luaL_buffinit(L, &b); - while (err == 0) { - size_t pos, len; cchar *data; - err = buf_contents(L, buf, &data, &len); - pos = 0; - while (pos < len && data[pos] != '\n') { - luaL_putchar(&b, data[pos]); - pos++; - } - if (pos < len) { /* found '\n' */ - buf_skip(L, buf, pos+1); /* skip '\n' too */ - break; /* we are done */ - } else /* reached the end of the buffer */ - buf_skip(L, buf, pos); + buf_skip(buf, pos); } luaL_pushresult(&b); return err; @@ -303,38 +216,32 @@ static int recvunixline(lua_State *L, p_buf buf) /*-------------------------------------------------------------------------*\ * Skips a given number of bytes in read buffer -* Input -* buf: buffer structure -* len: number of bytes to skip \*-------------------------------------------------------------------------*/ -static void buf_skip(lua_State *L, p_buf buf, size_t len) +static +void buf_skip(p_buf buf, size_t count) { - buf->buf_first += len; - if (buf_isempty(L, buf)) buf->buf_first = buf->buf_last = 0; + buf->first += count; + if (buf_isempty(buf)) + buf->first = buf->last = 0; } /*-------------------------------------------------------------------------*\ * Return any data available in buffer, or get more data from transport layer * if buffer is empty. -* Input -* buf: buffer structure -* Output -* data: pointer to buffer start -* len: buffer buffer length -* Returns -* PRIV_DONE, PRIV_CLOSED, PRIV_TIMEOUT ... \*-------------------------------------------------------------------------*/ -static int buf_contents(lua_State *L, p_buf buf, cchar **data, size_t *len) +static +int buf_get(p_buf buf, const char **data, size_t *count) { - int err = PRIV_DONE; - p_base base = buf->buf_base; - if (buf_isempty(L, buf)) { - size_t done; - err = base->base_receive(L, base, buf->buf_data, BUF_SIZE, &done); - buf->buf_first = 0; - buf->buf_last = done; + int err = IO_DONE; + p_io io = buf->io; + p_tm tm = buf->tm; + if (buf_isempty(buf)) { + size_t got; + err = io->recv(io->ctx, buf->data, BUF_SIZE, &got, tm_get(tm)); + buf->first = 0; + buf->last = got; } - *len = buf->buf_last - buf->buf_first; - *data = buf->buf_data + buf->buf_first; + *count = buf->last - buf->first; + *data = buf->data + buf->first; return err; } diff --git a/src/buffer.h b/src/buffer.h index 4943e3b..3ffc145 100644 --- a/src/buffer.h +++ b/src/buffer.h @@ -3,11 +3,12 @@ * * RCS ID: $Id$ \*=========================================================================*/ -#ifndef BUF_H_ -#define BUF_H_ +#ifndef BUF_H +#define BUF_H #include -#include "lsbase.h" +#include "io.h" +#include "tm.h" /* buffer size in bytes */ #define BUF_SIZE 8192 @@ -15,10 +16,11 @@ /*-------------------------------------------------------------------------*\ * Buffer control structure \*-------------------------------------------------------------------------*/ -typedef struct t_buf_tag { - size_t buf_first, buf_last; - char buf_data[BUF_SIZE]; - p_base buf_base; +typedef struct t_buf_ { + p_io io; /* IO driver used for this buffer */ + p_tm tm; /* timeout management for this buffer */ + size_t first, last; /* index of first and last bytes of stored data */ + char data[BUF_SIZE]; /* storage space for buffer data */ } t_buf; typedef t_buf *p_buf; @@ -26,9 +28,9 @@ typedef t_buf *p_buf; * Exported functions \*-------------------------------------------------------------------------*/ void buf_open(lua_State *L); -void buf_init(lua_State *L, p_buf buf, p_base base); -int buf_send(lua_State *L, p_buf buf); -int buf_receive(lua_State *L, p_buf buf); -int buf_isempty(lua_State *L, p_buf buf); +void buf_init(p_buf buf, p_io io, p_tm tm); +int buf_meth_send(lua_State *L, p_buf buf); +int buf_meth_receive(lua_State *L, p_buf buf); +int buf_isempty(p_buf buf); -#endif /* BUF_H_ */ +#endif /* BUF_H */ diff --git a/src/ftp.lua b/src/ftp.lua index 4017eb5..c48f2c7 100644 --- a/src/ftp.lua +++ b/src/ftp.lua @@ -7,7 +7,8 @@ ----------------------------------------------------------------------------- local Public, Private = {}, {} -socket.ftp = Public +local socket = _G[LUASOCKET_LIBNAME] -- get LuaSocket namespace +socket.ftp = Public -- create ftp sub namespace ----------------------------------------------------------------------------- -- Program constants @@ -22,6 +23,33 @@ Public.EMAIL = "anonymous@anonymous.org" -- block size used in transfers Public.BLOCKSIZE = 8192 +----------------------------------------------------------------------------- +-- Tries to get a pattern from the server and closes socket on error +-- sock: socket connected to the server +-- pattern: pattern to receive +-- Returns +-- received pattern on success +-- nil followed by error message on error +----------------------------------------------------------------------------- +function Private.try_receive(sock, pattern) + local data, err = sock:receive(pattern) + if not data then sock:close() end + return data, err +end + +----------------------------------------------------------------------------- +-- Tries to send data to the server and closes socket on error +-- sock: socket connected to the server +-- data: data to send +-- Returns +-- err: error message if any, nil if successfull +----------------------------------------------------------------------------- +function Private.try_send(sock, data) + local sent, err = sock:send(data) + if not sent then sock:close() end + return err +end + ----------------------------------------------------------------------------- -- Tries to send DOS mode lines. Closes socket on error. -- Input @@ -31,24 +59,7 @@ Public.BLOCKSIZE = 8192 -- err: message in case of error, nil if successfull ----------------------------------------------------------------------------- function Private.try_sendline(sock, line) - local err = sock:send(line .. "\r\n") - if err then sock:close() end - return err -end - ------------------------------------------------------------------------------ --- Tries to get a pattern from the server and closes socket on error --- sock: socket connected to the server --- ...: pattern to receive --- Returns --- ...: received pattern --- err: error message if any ------------------------------------------------------------------------------ -function Private.try_receive(...) - local sock = arg[1] - local data, err = sock.receive(unpack(arg)) - if err then sock:close() end - return data, err + return Private.try_send(sock, line .. "\r\n") end ----------------------------------------------------------------------------- @@ -307,20 +318,20 @@ end -- nil if successfull, or an error message in case of error ----------------------------------------------------------------------------- function Private.send_indirect(data, send_cb, chunk, size) - local sent, err - sent = 0 + local total, sent, err + total = 0 while 1 do if type(chunk) ~= "string" or type(size) ~= "number" then data:close() if not chunk and type(size) == "string" then return size else return "invalid callback return" end end - err = data:send(chunk) + sent, err = data:send(chunk) if err then data:close() return err end - sent = sent + string.len(chunk) + total = total + sent if sent >= size then break end chunk, size = send_cb() end diff --git a/src/http.lua b/src/http.lua index 59645ee..d531a2f 100644 --- a/src/http.lua +++ b/src/http.lua @@ -7,7 +7,8 @@ ----------------------------------------------------------------------------- local Public, Private = {}, {} -socket.http = Public +local socket = _G[LUASOCKET_LIBNAME] -- get LuaSocket namespace +socket.http = Public -- create http sub namespace ----------------------------------------------------------------------------- -- Program constants @@ -24,19 +25,15 @@ Public.BLOCKSIZE = 8192 ----------------------------------------------------------------------------- -- Tries to get a pattern from the server and closes socket on error -- sock: socket connected to the server --- ...: pattern to receive +-- pattern: pattern to receive -- Returns --- ...: received pattern --- err: error message if any +-- received pattern on success +-- nil followed by error message on error ----------------------------------------------------------------------------- -function Private.try_receive(...) - local sock = arg[1] - local data, err = sock.receive(unpack(arg)) - if err then - sock:close() - return nil, err - end - return data +function Private.try_receive(sock, pattern) + local data, err = sock:receive(pattern) + if not data then sock:close() end + return data, err end ----------------------------------------------------------------------------- @@ -47,8 +44,8 @@ end -- err: error message if any, nil if successfull ----------------------------------------------------------------------------- function Private.try_send(sock, data) - local err = sock:send(data) - if err then sock:close() end + local sent, err = sock:send(data) + if not sent then sock:close() end return err end @@ -285,21 +282,21 @@ end -- nil if successfull, or an error message in case of error ----------------------------------------------------------------------------- function Private.send_indirect(data, send_cb, chunk, size) - local sent, err - sent = 0 + local total, sent, err + total = 0 while 1 do if type(chunk) ~= "string" or type(size) ~= "number" then data:close() if not chunk and type(size) == "string" then return size else return "invalid callback return" end end - err = data:send(chunk) + sent, err = data:send(chunk) if err then data:close() return err end - sent = sent + string.len(chunk) - if sent >= size then break end + total = total + sent + if total >= size then break end chunk, size = send_cb() end end diff --git a/src/inet.c b/src/inet.c index 341c60e..f20762f 100644 --- a/src/inet.c +++ b/src/inet.c @@ -1,12 +1,5 @@ /*=========================================================================*\ -* Internet domain class: inherits from the Socket class, and implement -* a few methods shared by all internet related objects -* Lua methods: -* getpeername: gets socket peer ip address and port -* getsockname: gets local socket ip address and port -* Global Lua fuctions: -* toip: gets resolver info on host name -* tohostname: gets resolver info on dotted-quad +* Internet domain functions * * RCS ID: $Id$ \*=========================================================================*/ @@ -15,23 +8,27 @@ #include #include -#include "lsinet.h" -#include "lssock.h" -#include "lscompat.h" +#include "luasocket.h" +#include "inet.h" /*=========================================================================*\ * Internal function prototypes. \*=========================================================================*/ -static int inet_lua_toip(lua_State *L); -static int inet_lua_tohostname(lua_State *L); -static int inet_lua_getpeername(lua_State *L); -static int inet_lua_getsockname(lua_State *L); +static int inet_global_toip(lua_State *L); +static int inet_global_tohostname(lua_State *L); + static void inet_pushresolved(lua_State *L, struct hostent *hp); -#ifdef COMPAT_INETATON -static int inet_aton(cchar *cp, struct in_addr *inp); +#ifdef INET_ATON +static int inet_aton(const char *cp, struct in_addr *inp); #endif +static luaL_reg func[] = { + { "toip", inet_global_toip }, + { "tohostname", inet_global_tohostname }, + { NULL, NULL} +}; + /*=========================================================================*\ * Exported functions \*=========================================================================*/ @@ -40,39 +37,7 @@ static int inet_aton(cchar *cp, struct in_addr *inp); \*-------------------------------------------------------------------------*/ void inet_open(lua_State *L) { - lua_pushcfunction(L, inet_lua_toip); - priv_newglobal(L, "toip"); - lua_pushcfunction(L, inet_lua_tohostname); - priv_newglobal(L, "tohostname"); - priv_newglobalmethod(L, "getsockname"); - priv_newglobalmethod(L, "getpeername"); -} - -/*-------------------------------------------------------------------------*\ -* Hook lua methods to methods table. -* Input -* lsclass: class name -\*-------------------------------------------------------------------------*/ -void inet_inherit(lua_State *L, cchar *lsclass) -{ - unsigned int i; - static struct luaL_reg funcs[] = { - {"getsockname", inet_lua_getsockname}, - {"getpeername", inet_lua_getpeername}, - }; - sock_inherit(L, lsclass); - for (i = 0; i < sizeof(funcs)/sizeof(funcs[0]); i++) { - lua_pushcfunction(L, funcs[i].func); - priv_setmethod(L, lsclass, funcs[i].name); - } -} - -/*-------------------------------------------------------------------------*\ -* Constructs the object -\*-------------------------------------------------------------------------*/ -void inet_construct(lua_State *L, p_inet inet) -{ - sock_construct(L, (p_sock) inet); + luaL_openlib(L, LUASOCKET_LIBNAME, func, 0); } /*=========================================================================*\ @@ -87,17 +52,18 @@ void inet_construct(lua_State *L, p_inet inet) * On success: first IP address followed by a resolved table * On error: nil, followed by an error message \*-------------------------------------------------------------------------*/ -static int inet_lua_toip(lua_State *L) +static int inet_global_toip(lua_State *L) { - cchar *address = luaL_checkstring(L, 1); + const char *address = luaL_checkstring(L, 1); struct in_addr addr; struct hostent *hp; if (inet_aton(address, &addr)) hp = gethostbyaddr((char *) &addr, sizeof(addr), AF_INET); - else hp = gethostbyname(address); + else + hp = gethostbyname(address); if (!hp) { lua_pushnil(L); - lua_pushstring(L, compat_hoststrerror()); + lua_pushstring(L, sock_hoststrerror()); return 2; } addr = *((struct in_addr *) hp->h_addr); @@ -115,17 +81,18 @@ static int inet_lua_toip(lua_State *L) * On success: canonic name followed by a resolved table * On error: nil, followed by an error message \*-------------------------------------------------------------------------*/ -static int inet_lua_tohostname(lua_State *L) +static int inet_global_tohostname(lua_State *L) { - cchar *address = luaL_checkstring(L, 1); + const char *address = luaL_checkstring(L, 1); struct in_addr addr; struct hostent *hp; if (inet_aton(address, &addr)) hp = gethostbyaddr((char *) &addr, sizeof(addr), AF_INET); - else hp = gethostbyname(address); + else + hp = gethostbyname(address); if (!hp) { lua_pushnil(L); - lua_pushstring(L, compat_hoststrerror()); + lua_pushstring(L, sock_hoststrerror()); return 2; } lua_pushstring(L, hp->h_name); @@ -138,18 +105,17 @@ static int inet_lua_tohostname(lua_State *L) \*=========================================================================*/ /*-------------------------------------------------------------------------*\ * Retrieves socket peer name -* Lua Input: sock +* Input: * sock: socket * Lua Returns * On success: ip address and port of peer * On error: nil \*-------------------------------------------------------------------------*/ -static int inet_lua_getpeername(lua_State *L) +int inet_meth_getpeername(lua_State *L, p_sock ps) { - p_sock sock = (p_sock) lua_touserdata(L, 1); struct sockaddr_in peer; size_t peer_len = sizeof(peer); - if (getpeername(sock->fd, (SA *) &peer, &peer_len) < 0) { + if (getpeername(*ps, (SA *) &peer, &peer_len) < 0) { lua_pushnil(L); return 1; } @@ -160,18 +126,17 @@ static int inet_lua_getpeername(lua_State *L) /*-------------------------------------------------------------------------*\ * Retrieves socket local name -* Lua Input: sock +* Input: * sock: socket * Lua Returns * On success: local ip address and port * On error: nil \*-------------------------------------------------------------------------*/ -static int inet_lua_getsockname(lua_State *L) +int inet_meth_getsockname(lua_State *L, p_sock ps) { - p_sock sock = (p_sock) lua_touserdata(L, 1); struct sockaddr_in local; size_t local_len = sizeof(local); - if (getsockname(sock->fd, (SA *) &local, &local_len) < 0) { + if (getsockname(*ps, (SA *) &local, &local_len) < 0) { lua_pushnil(L); return 1; } @@ -222,47 +187,53 @@ static void inet_pushresolved(lua_State *L, struct hostent *hp) } /*-------------------------------------------------------------------------*\ -* Tries to create a TCP socket and connect to remote address (address, port) +* Tries to connect to remote address (address, port) * Input -* client: socket structure to be used +* ps: pointer to socket * address: host name or ip address * port: port number to bind to * Returns * NULL in case of success, error message otherwise \*-------------------------------------------------------------------------*/ -cchar *inet_tryconnect(p_inet inet, cchar *address, ushort port) +const char *inet_tryconnect(p_sock ps, const char *address, ushort port) { struct sockaddr_in remote; memset(&remote, 0, sizeof(remote)); remote.sin_family = AF_INET; remote.sin_port = htons(port); - if (!strlen(address) || !inet_aton(address, &remote.sin_addr)) { - struct hostent *hp = gethostbyname(address); - struct in_addr **addr; - if (!hp) return compat_hoststrerror(); - addr = (struct in_addr **) hp->h_addr_list; - memcpy(&remote.sin_addr, *addr, sizeof(struct in_addr)); - } - compat_setblocking(inet->fd); - if (compat_connect(inet->fd, (SA *) &remote, sizeof(remote)) < 0) { - const char *err = compat_connectstrerror(); - compat_close(inet->fd); - inet->fd = COMPAT_INVALIDFD; + if (strcmp(address, "*")) { + if (!strlen(address) || !inet_aton(address, &remote.sin_addr)) { + struct hostent *hp = gethostbyname(address); + struct in_addr **addr; + remote.sin_family = AF_INET; + if (!hp) return sock_hoststrerror(); + addr = (struct in_addr **) hp->h_addr_list; + memcpy(&remote.sin_addr, *addr, sizeof(struct in_addr)); + } + } else remote.sin_family = AF_UNSPEC; + sock_setblocking(ps); + const char *err = sock_connect(ps, (SA *) &remote, sizeof(remote)); + if (err) { + sock_destroy(ps); + *ps = SOCK_INVALID; return err; - } - compat_setnonblocking(inet->fd); - return NULL; + } else { + sock_setnonblocking(ps); + return NULL; + } } /*-------------------------------------------------------------------------*\ -* Tries to create a TCP socket and bind it to (address, port) +* Tries to bind socket to (address, port) * Input +* sock: pointer to socket * address: host name or ip address * port: port number to bind to * Returns * NULL in case of success, error message otherwise \*-------------------------------------------------------------------------*/ -cchar *inet_trybind(p_inet inet, cchar *address, ushort port) +const char *inet_trybind(p_sock ps, const char *address, ushort port, + int backlog) { struct sockaddr_in local; memset(&local, 0, sizeof(local)); @@ -274,34 +245,33 @@ cchar *inet_trybind(p_inet inet, cchar *address, ushort port) (!strlen(address) || !inet_aton(address, &local.sin_addr))) { struct hostent *hp = gethostbyname(address); struct in_addr **addr; - if (!hp) return compat_hoststrerror(); + if (!hp) return sock_hoststrerror(); addr = (struct in_addr **) hp->h_addr_list; memcpy(&local.sin_addr, *addr, sizeof(struct in_addr)); } - compat_setblocking(inet->fd); - if (compat_bind(inet->fd, (SA *) &local, sizeof(local)) < 0) { - const char *err = compat_bindstrerror(); - compat_close(inet->fd); - inet->fd = COMPAT_INVALIDFD; + sock_setblocking(ps); + const char *err = sock_bind(ps, (SA *) &local, sizeof(local)); + if (err) { + sock_destroy(ps); + *ps = SOCK_INVALID; return err; + } else { + sock_setnonblocking(ps); + if (backlog > 0) sock_listen(ps, backlog); + return NULL; } - compat_setnonblocking(inet->fd); - return NULL; } /*-------------------------------------------------------------------------*\ * Tries to create a new inet socket * Input -* udp: udp structure +* sock: pointer to socket * Returns * NULL if successfull, error message on error \*-------------------------------------------------------------------------*/ -cchar *inet_trysocket(p_inet inet, int type) +const char *inet_trycreate(p_sock ps, int type) { - if (inet->fd != COMPAT_INVALIDFD) compat_close(inet->fd); - inet->fd = compat_socket(AF_INET, type, 0); - if (inet->fd == COMPAT_INVALIDFD) return compat_socketstrerror(); - else return NULL; + return sock_create(ps, AF_INET, type, 0); } /*-------------------------------------------------------------------------*\ diff --git a/src/inet.h b/src/inet.h index 93fcedf..bcefc5b 100644 --- a/src/inet.h +++ b/src/inet.h @@ -1,38 +1,26 @@ /*=========================================================================*\ -* Internet domain class: inherits from the Socket class, and implement -* a few methods shared by all internet related objects +* Internet domain functions * * RCS ID: $Id$ \*=========================================================================*/ -#ifndef INET_H_ -#define INET_H_ +#ifndef INET_H +#define INET_H #include -#include "lssock.h" - -/* class name */ -#define INET_CLASS "luasocket(inet)" - -/*-------------------------------------------------------------------------*\ -* Socket fields -\*-------------------------------------------------------------------------*/ -#define INET_FIELDS SOCK_FIELDS - -/*-------------------------------------------------------------------------*\ -* Socket structure -\*-------------------------------------------------------------------------*/ -typedef t_sock t_inet; -typedef t_inet *p_inet; +#include "sock.h" /*-------------------------------------------------------------------------*\ * Exported functions \*-------------------------------------------------------------------------*/ void inet_open(lua_State *L); -void inet_construct(lua_State *L, p_inet inet); -void inet_inherit(lua_State *L, cchar *lsclass); -cchar *inet_tryconnect(p_sock sock, cchar *address, ushort); -cchar *inet_trybind(p_sock sock, cchar *address, ushort); -cchar *inet_trysocket(p_inet inet, int type); +const char *inet_tryconnect(p_sock ps, const char *address, + unsigned short port); +const char *inet_trybind(p_sock ps, const char *address, + unsigned short port, int backlog); +const char *inet_trycreate(p_sock ps, int type); + +int inet_meth_getpeername(lua_State *L, p_sock ps); +int inet_meth_getsockname(lua_State *L, p_sock ps); #endif /* INET_H_ */ diff --git a/src/io.c b/src/io.c new file mode 100644 index 0000000..902124a --- /dev/null +++ b/src/io.c @@ -0,0 +1,8 @@ +#include "io.h" + +void io_init(p_io io, p_send send, p_recv recv, void *ctx) +{ + io->send = send; + io->recv = recv; + io->ctx = ctx; +} diff --git a/src/io.h b/src/io.h new file mode 100644 index 0000000..b5b7f1d --- /dev/null +++ b/src/io.h @@ -0,0 +1,34 @@ +#ifndef IO_H +#define IO_H + +#include "error.h" + +/* interface to send function */ +typedef int (*p_send) ( + void *ctx, /* context needed by send */ + const char *data, /* pointer to buffer with data to send */ + size_t count, /* number of bytes to send from buffer */ + size_t *sent, /* number of bytes sent uppon return */ + int timeout /* number of miliseconds left for transmission */ +); + +/* interface to recv function */ +typedef int (*p_recv) ( + void *ctx, /* context needed by recv */ + char *data, /* pointer to buffer where data will be writen */ + size_t count, /* number of bytes to receive into buffer */ + size_t *got, /* number of bytes received uppon return */ + int timeout /* number of miliseconds left for transmission */ +); + +/* IO driver definition */ +typedef struct t_io_ { + void *ctx; /* context needed by send/recv */ + p_send send; /* send function pointer */ + p_recv recv; /* receive function pointer */ +} t_io; +typedef t_io *p_io; + +void io_init(p_io io, p_send send, p_recv recv, void *ctx); + +#endif /* IO_H */ diff --git a/src/luasocket.c b/src/luasocket.c index bcc705f..53f8c21 100644 --- a/src/luasocket.c +++ b/src/luasocket.c @@ -23,18 +23,13 @@ * LuaSocket includes \*=========================================================================*/ #include "luasocket.h" -#include "lspriv.h" -#include "lsselect.h" -#include "lscompat.h" -#include "lsbase.h" -#include "lstm.h" -#include "lsbuf.h" -#include "lssock.h" -#include "lsinet.h" -#include "lstcpc.h" -#include "lstcps.h" -#include "lstcps.h" -#include "lsudp.h" + +#include "tm.h" +#include "buf.h" +#include "sock.h" +#include "inet.h" +#include "tcp.h" +#include "udp.h" /*=========================================================================*\ * Exported functions @@ -42,34 +37,29 @@ /*-------------------------------------------------------------------------*\ * Initializes all library modules. \*-------------------------------------------------------------------------*/ -LUASOCKET_API int lua_socketlibopen(lua_State *L) +LUASOCKET_API int luaopen_socketlib(lua_State *L) { - compat_open(L); - priv_open(L); - select_open(L); - base_open(L); - tm_open(L); - fd_open(L); - sock_open(L); - inet_open(L); - tcpc_open(L); - buf_open(L); - tcps_open(L); - udp_open(L); -#ifdef LUASOCKET_DOFILE - lua_dofile(L, "concat.lua"); - lua_dofile(L, "code.lua"); - lua_dofile(L, "url.lua"); - lua_dofile(L, "http.lua"); - lua_dofile(L, "smtp.lua"); - lua_dofile(L, "ftp.lua"); -#else -#include "concat.loh" -#include "code.loh" -#include "url.loh" -#include "http.loh" -#include "smtp.loh" -#include "ftp.loh" + /* create namespace table */ + lua_pushstring(L, LUASOCKET_LIBNAME); + lua_newtable(L); +#ifdef LUASOCKET_DEBUG + lua_pushstring(L, "debug"); + lua_pushnumber(L, 1); + lua_settable(L, -3); #endif + lua_settable(L, LUA_GLOBALSINDEX); + /* make sure modules know what is our namespace */ + lua_pushstring(L, "LUASOCKET_LIBNAME"); + lua_pushstring(L, LUASOCKET_LIBNAME); + lua_settable(L, LUA_GLOBALSINDEX); + /* initialize all modules */ + sock_open(L); + tm_open(L); + buf_open(L); + inet_open(L); + tcp_open(L); + udp_open(L); + /* load all Lua code */ + lua_dofile(L, "luasocket.lua"); return 0; } diff --git a/src/luasocket.h b/src/luasocket.h index fd22606..6c25af2 100644 --- a/src/luasocket.h +++ b/src/luasocket.h @@ -5,8 +5,8 @@ * * RCS ID: $Id$ \*=========================================================================*/ -#ifndef _LUASOCKET_H_ -#define _LUASOCKET_H_ +#ifndef LUASOCKET_H +#define LUASOCKET_H /*-------------------------------------------------------------------------*\ * Current luasocket version @@ -28,6 +28,6 @@ /*-------------------------------------------------------------------------*\ * Initializes the library. \*-------------------------------------------------------------------------*/ -LUASOCKET_API int lua_socketlibopen(lua_State *L); +LUASOCKET_API int luaopen_socketlib(lua_State *L); -#endif /* _LUASOCKET_H_ */ +#endif /* LUASOCKET_H */ diff --git a/src/mbox.lua b/src/mbox.lua index 4a72331..f52719b 100644 --- a/src/mbox.lua +++ b/src/mbox.lua @@ -5,10 +5,10 @@ mbox = Public function Public.split_message(message_s) local message = {} message_s = string.gsub(message_s, "\r\n", "\n") - string.gsub(message_s, "^(.-\n)\n", function (h) %message.headers = h end) - string.gsub(message_s, "^.-\n\n(.*)", function (b) %message.body = b end) + string.gsub(message_s, "^(.-\n)\n", function (h) message.headers = h end) + string.gsub(message_s, "^.-\n\n(.*)", function (b) message.body = b end) if not message.body then - string.gsub(message_s, "^\n(.*)", function (b) %message.body = b end) + string.gsub(message_s, "^\n(.*)", function (b) message.body = b end) end if not message.headers and not message.body then message.headers = message_s @@ -20,7 +20,7 @@ function Public.split_headers(headers_s) local headers = {} headers_s = string.gsub(headers_s, "\r\n", "\n") headers_s = string.gsub(headers_s, "\n[ ]+", " ") - string.gsub("\n" .. headers_s, "\n([^\n]+)", function (h) table.insert(%headers, h) end) + string.gsub("\n" .. headers_s, "\n([^\n]+)", function (h) table.insert(headers, h) end) return headers end @@ -32,10 +32,10 @@ function Public.parse_header(header_s) end function Public.parse_headers(headers_s) - local headers_t = %Public.split_headers(headers_s) + local headers_t = Public.split_headers(headers_s) local headers = {} for i = 1, table.getn(headers_t) do - local name, value = %Public.parse_header(headers_t[i]) + local name, value = Public.parse_header(headers_t[i]) if name then name = string.lower(name) if headers[name] then @@ -73,16 +73,16 @@ function Public.split_mbox(mbox_s) end function Public.parse(mbox_s) - local mbox = %Public.split_mbox(mbox_s) + local mbox = Public.split_mbox(mbox_s) for i = 1, table.getn(mbox) do - mbox[i] = %Public.parse_message(mbox[i]) + mbox[i] = Public.parse_message(mbox[i]) end return mbox end function Public.parse_message(message_s) local message = {} - message.headers, message.body = %Public.split_message(message_s) - message.headers = %Public.parse_headers(message.headers) + message.headers, message.body = Public.split_message(message_s) + message.headers = Public.parse_headers(message.headers) return message end diff --git a/src/smtp.lua b/src/smtp.lua index 0ba2b0f..604f79b 100644 --- a/src/smtp.lua +++ b/src/smtp.lua @@ -7,7 +7,8 @@ ----------------------------------------------------------------------------- local Public, Private = {}, {} -socket.smtp = Public +local socket = _G[LUASOCKET_LIBNAME] -- get LuaSocket namespace +socket.smtp = Public -- create smtp sub namespace ----------------------------------------------------------------------------- -- Program constants @@ -23,32 +24,30 @@ Public.DOMAIN = os.getenv("SERVER_NAME") or "localhost" Public.SERVER = "localhost" ----------------------------------------------------------------------------- --- Tries to send data through socket. Closes socket on error. --- Input --- sock: server socket --- data: string to be sent +-- Tries to get a pattern from the server and closes socket on error +-- sock: socket connected to the server +-- pattern: pattern to receive -- Returns --- err: message in case of error, nil if successfull +-- received pattern on success +-- nil followed by error message on error ----------------------------------------------------------------------------- -function Private.try_send(sock, data) - local err = sock:send(data) - if err then sock:close() end - return err +function Private.try_receive(sock, pattern) + local data, err = sock:receive(pattern) + if not data then sock:close() end + return data, err end ----------------------------------------------------------------------------- --- Tries to get a pattern from the server and closes socket on error --- sock: socket opened to the server --- ...: pattern to receive +-- Tries to send data to the server and closes socket on error +-- sock: socket connected to the server +-- data: data to send -- Returns --- ...: received pattern --- err: error message if any +-- err: error message if any, nil if successfull ----------------------------------------------------------------------------- -function Private.try_receive(...) - local sock = arg[1] - local data, err = sock.receive(unpack(arg)) - if err then sock:close() end - return data, err +function Private.try_send(sock, data) + local sent, err = sock:send(data) + if not sent then sock:close() end + return err end ----------------------------------------------------------------------------- diff --git a/src/tcp.c b/src/tcp.c new file mode 100644 index 0000000..db6a38e --- /dev/null +++ b/src/tcp.c @@ -0,0 +1,222 @@ +/*=========================================================================*\ +* TCP object +* +* RCS ID: $Id$ +\*=========================================================================*/ +#include + +#include +#include + +#include "luasocket.h" + +#include "aux.h" +#include "inet.h" +#include "tcp.h" + +/*=========================================================================*\ +* Internal function prototypes +\*=========================================================================*/ +static int tcp_global_create(lua_State *L); +static int tcp_meth_connect(lua_State *L); +static int tcp_meth_bind(lua_State *L); +static int tcp_meth_send(lua_State *L); +static int tcp_meth_getsockname(lua_State *L); +static int tcp_meth_getpeername(lua_State *L); +static int tcp_meth_receive(lua_State *L); +static int tcp_meth_accept(lua_State *L); +static int tcp_meth_close(lua_State *L); +static int tcp_meth_timeout(lua_State *L); + +/* tcp object methods */ +static luaL_reg tcp[] = { + {"connect", tcp_meth_connect}, + {"send", tcp_meth_send}, + {"receive", tcp_meth_receive}, + {"bind", tcp_meth_bind}, + {"accept", tcp_meth_accept}, + {"setpeername", tcp_meth_connect}, + {"setsockname", tcp_meth_bind}, + {"getpeername", tcp_meth_getpeername}, + {"getsockname", tcp_meth_getsockname}, + {"timeout", tcp_meth_timeout}, + {"close", tcp_meth_close}, + {NULL, NULL} +}; + +/* functions in library namespace */ +static luaL_reg func[] = { + {"tcp", tcp_global_create}, + {NULL, NULL} +}; + +/*-------------------------------------------------------------------------*\ +* Initializes module +\*-------------------------------------------------------------------------*/ +void tcp_open(lua_State *L) +{ + /* create classes */ + aux_newclass(L, "tcp{master}", tcp); + aux_newclass(L, "tcp{client}", tcp); + aux_newclass(L, "tcp{server}", tcp); + /* create class groups */ + aux_add2group(L, "tcp{client}", "tcp{client, server}"); + aux_add2group(L, "tcp{server}", "tcp{client, server}"); + aux_add2group(L, "tcp{master}", "tcp{any}"); + aux_add2group(L, "tcp{client}", "tcp{any}"); + aux_add2group(L, "tcp{server}", "tcp{any}"); + /* define library functions */ + luaL_openlib(L, LUASOCKET_LIBNAME, func, 0); + lua_pop(L, 1); +} + +/*=========================================================================*\ +* Lua methods +\*=========================================================================*/ +/*-------------------------------------------------------------------------*\ +* Just call buffered IO methods +\*-------------------------------------------------------------------------*/ +static int tcp_meth_send(lua_State *L) +{ + p_tcp tcp = (p_tcp) aux_checkclass(L, "tcp{client}", 1); + return buf_meth_send(L, &tcp->buf); +} + +static int tcp_meth_receive(lua_State *L) +{ + p_tcp tcp = (p_tcp) aux_checkclass(L, "tcp{client}", 1); + return buf_meth_receive(L, &tcp->buf); +} + +/*-------------------------------------------------------------------------*\ +* Just call inet methods +\*-------------------------------------------------------------------------*/ +static int tcp_meth_getpeername(lua_State *L) +{ + p_tcp tcp = (p_tcp) aux_checkclass(L, "tcp{client}", 1); + return inet_meth_getpeername(L, &tcp->sock); +} + +static int tcp_meth_getsockname(lua_State *L) +{ + p_tcp tcp = (p_tcp) aux_checkgroup(L, "tcp{client, server}", 1); + return inet_meth_getsockname(L, &tcp->sock); +} + +/*-------------------------------------------------------------------------*\ +* Just call tm methods +\*-------------------------------------------------------------------------*/ +static int tcp_meth_timeout(lua_State *L) +{ + p_tcp tcp = (p_tcp) aux_checkgroup(L, "tcp{any}", 1); + return tm_meth_timeout(L, &tcp->tm); +} + +/*-------------------------------------------------------------------------*\ +* Closes socket used by object +\*-------------------------------------------------------------------------*/ +static int tcp_meth_close(lua_State *L) +{ + p_tcp tcp = (p_tcp) aux_checkgroup(L, "tcp{any}", 1); + sock_destroy(&tcp->sock); + return 0; +} + +/*-------------------------------------------------------------------------*\ +* Turns a master tcp object into a client object. +\*-------------------------------------------------------------------------*/ +static int tcp_meth_connect(lua_State *L) +{ + p_tcp tcp = (p_tcp) aux_checkclass(L, "tcp{master}", 1); + const char *address = luaL_checkstring(L, 2); + unsigned short port = (ushort) luaL_checknumber(L, 3); + const char *err = inet_tryconnect(&tcp->sock, address, port); + if (err) { + lua_pushnil(L); + lua_pushstring(L, err); + return 2; + } + /* turn master object into a client object */ + aux_setclass(L, "tcp{client}", 1); + lua_pushnumber(L, 1); + return 1; +} + +/*-------------------------------------------------------------------------*\ +* Turns a master object into a server object +\*-------------------------------------------------------------------------*/ +static int tcp_meth_bind(lua_State *L) +{ + p_tcp tcp = (p_tcp) aux_checkclass(L, "tcp{master}", 1); + const char *address = luaL_checkstring(L, 2); + unsigned short port = (ushort) luaL_checknumber(L, 3); + int backlog = (int) luaL_optnumber(L, 4, 1); + const char *err = inet_trybind(&tcp->sock, address, port, backlog); + if (err) { + lua_pushnil(L); + lua_pushstring(L, err); + return 2; + } + /* turn master object into a server object */ + aux_setclass(L, "tcp{server}", 1); + lua_pushnumber(L, 1); + return 1; +} + +/*-------------------------------------------------------------------------*\ +* Waits for and returns a client object attempting connection to the +* server object +\*-------------------------------------------------------------------------*/ +static int tcp_meth_accept(lua_State *L) +{ + struct sockaddr_in addr; + size_t addr_len = sizeof(addr); + p_tcp server = (p_tcp) aux_checkclass(L, "tcp{server}", 1); + p_tm tm = &server->tm; + p_tcp client = lua_newuserdata(L, sizeof(t_tcp)); + tm_markstart(tm); + aux_setclass(L, "tcp{client}", -1); + for ( ;; ) { + sock_accept(&server->sock, &client->sock, + (SA *) &addr, &addr_len, tm_get(tm)); + if (client->sock == SOCK_INVALID) { + if (tm_get(tm) == 0) { + lua_pushnil(L); + error_push(L, IO_TIMEOUT); + return 2; + } + } else break; + } + /* initialize remaining structure fields */ + io_init(&client->io, (p_send) sock_send, (p_recv) sock_recv, &client->sock); + tm_init(&client->tm, -1, -1); + buf_init(&client->buf, &client->io, &client->tm); + return 1; +} + +/*=========================================================================*\ +* Library functions +\*=========================================================================*/ +/*-------------------------------------------------------------------------*\ +* Creates a master tcp object +\*-------------------------------------------------------------------------*/ +int tcp_global_create(lua_State *L) +{ + /* allocate tcp object */ + p_tcp tcp = (p_tcp) lua_newuserdata(L, sizeof(t_tcp)); + /* set its type as master object */ + aux_setclass(L, "tcp{master}", -1); + /* try to allocate a system socket */ + const char *err = inet_trycreate(&tcp->sock, SOCK_STREAM); + if (err) { /* get rid of object on stack and push error */ + lua_pop(L, 1); + lua_pushnil(L); + lua_pushstring(L, err); + return 2; + } + /* initialize remaining structure fields */ + io_init(&tcp->io, (p_send) sock_send, (p_recv) sock_recv, &tcp->sock); + tm_init(&tcp->tm, -1, -1); + buf_init(&tcp->buf, &tcp->io, &tcp->tm); + return 1; +} diff --git a/src/tcp.h b/src/tcp.h new file mode 100644 index 0000000..d4cc65c --- /dev/null +++ b/src/tcp.h @@ -0,0 +1,20 @@ +#ifndef TCP_H +#define TCP_H + +#include + +#include "buf.h" +#include "tm.h" +#include "sock.h" + +typedef struct t_tcp_ { + t_sock sock; + t_io io; + t_buf buf; + t_tm tm; +} t_tcp; +typedef t_tcp *p_tcp; + +void tcp_open(lua_State *L); + +#endif diff --git a/src/timeout.c b/src/timeout.c index 5549c89..17878aa 100644 --- a/src/timeout.c +++ b/src/timeout.c @@ -1,18 +1,19 @@ /*=========================================================================*\ * Timeout management functions * Global Lua functions: -* _sleep: (debug mode only) -* _time: (debug mode only) +* _sleep +* _time * * RCS ID: $Id$ \*=========================================================================*/ +#include + #include #include -#include "lspriv.h" -#include "lstm.h" - -#include +#include "luasocket.h" +#include "aux.h" +#include "tm.h" #ifdef WIN32 #include @@ -28,78 +29,69 @@ static int tm_lua_time(lua_State *L); static int tm_lua_sleep(lua_State *L); +static luaL_reg func[] = { + { "time", tm_lua_time }, + { "sleep", tm_lua_sleep }, + { NULL, NULL } +}; + /*=========================================================================*\ * Exported functions. \*=========================================================================*/ /*-------------------------------------------------------------------------*\ -* Sets timeout limits -* Input -* tm: timeout control structure -* mode: block or return timeout -* value: timeout value in miliseconds +* Initialize structure \*-------------------------------------------------------------------------*/ -void tm_set(p_tm tm, int tm_block, int tm_return) +void tm_init(p_tm tm, int block, int total) { - tm->tm_block = tm_block; - tm->tm_return = tm_return; + tm->block = block; + tm->total = total; } /*-------------------------------------------------------------------------*\ -* Returns timeout limits -* Input -* tm: timeout control structure -* mode: block or return timeout -* value: timeout value in miliseconds +* Set and get timeout limits \*-------------------------------------------------------------------------*/ -void tm_get(p_tm tm, int *tm_block, int *tm_return) -{ - if (tm_block) *tm_block = tm->tm_block; - if (tm_return) *tm_return = tm->tm_return; -} +void tm_setblock(p_tm tm, int block) +{ tm->block = block; } +void tm_settotal(p_tm tm, int total) +{ tm->total = total; } +int tm_getblock(p_tm tm) +{ return tm->block; } +int tm_gettotal(p_tm tm) +{ return tm->total; } +int tm_getstart(p_tm tm) +{ return tm->start; } /*-------------------------------------------------------------------------*\ -* Determines how much time we have left for the current io operation -* an IO write operation. +* Determines how much time we have left for the current operation * Input * tm: timeout control structure * Returns * the number of ms left or -1 if there is no time limit \*-------------------------------------------------------------------------*/ -int tm_getremaining(p_tm tm) +int tm_get(p_tm tm) { /* no timeout */ - if (tm->tm_block < 0 && tm->tm_return < 0) + if (tm->block < 0 && tm->total < 0) return -1; /* there is no block timeout, we use the return timeout */ - else if (tm->tm_block < 0) - return MAX(tm->tm_return - tm_gettime() + tm->tm_start, 0); + else if (tm->block < 0) + return MAX(tm->total - tm_gettime() + tm->start, 0); /* there is no return timeout, we use the block timeout */ - else if (tm->tm_return < 0) - return tm->tm_block; + else if (tm->total < 0) + return tm->block; /* both timeouts are specified */ - else return MIN(tm->tm_block, - MAX(tm->tm_return - tm_gettime() + tm->tm_start, 0)); + else return MIN(tm->block, + MAX(tm->total - tm_gettime() + tm->start, 0)); } /*-------------------------------------------------------------------------*\ -* Marks the operation start time in sock structure +* Marks the operation start time in structure * Input * tm: timeout control structure \*-------------------------------------------------------------------------*/ void tm_markstart(p_tm tm) { - tm->tm_start = tm_gettime(); - tm->tm_end = tm->tm_start; -} - -/*-------------------------------------------------------------------------*\ -* Returns the length of the operation in ms -* Input -* tm: timeout control structure -\*-------------------------------------------------------------------------*/ -int tm_getelapsed(p_tm tm) -{ - return tm->tm_end - tm->tm_start; + tm->start = tm_gettime(); } /*-------------------------------------------------------------------------*\ @@ -125,11 +117,31 @@ int tm_gettime(void) \*-------------------------------------------------------------------------*/ void tm_open(lua_State *L) { - (void) L; - lua_pushcfunction(L, tm_lua_time); - priv_newglobal(L, "_time"); - lua_pushcfunction(L, tm_lua_sleep); - priv_newglobal(L, "_sleep"); + luaL_openlib(L, LUASOCKET_LIBNAME, func, 0); +} + +/*-------------------------------------------------------------------------*\ +* Sets timeout values for IO operations +* Lua Input: base, time [, mode] +* time: time out value in seconds +* mode: "b" for block timeout, "t" for total timeout. (default: b) +\*-------------------------------------------------------------------------*/ +int tm_meth_timeout(lua_State *L, p_tm tm) +{ + int ms = lua_isnil(L, 2) ? -1 : (int) (luaL_checknumber(L, 2)*1000.0); + const char *mode = luaL_optstring(L, 3, "b"); + switch (*mode) { + case 'b': + tm_setblock(tm, ms); + break; + case 'r': case 't': + tm_settotal(tm, ms); + break; + default: + luaL_argcheck(L, 0, 3, "invalid timeout mode"); + break; + } + return 0; } /*=========================================================================*\ diff --git a/src/timeout.h b/src/timeout.h index 1dc0a5a..43476cb 100644 --- a/src/timeout.h +++ b/src/timeout.h @@ -3,23 +3,29 @@ * * RCS ID: $Id$ \*=========================================================================*/ -#ifndef _TM_H -#define _TM_H +#ifndef TM_H +#define TM_H -typedef struct t_tm_tag { - int tm_return; - int tm_block; - int tm_start; - int tm_end; +#include + +/* timeout control structure */ +typedef struct t_tm_ { + int total; /* total number of miliseconds for operation */ + int block; /* maximum time for blocking calls */ + int start; /* time of start of operation */ } t_tm; typedef t_tm *p_tm; -void tm_set(p_tm tm, int tm_block, int tm_return); -int tm_getremaining(p_tm tm); -int tm_getelapsed(p_tm tm); -int tm_gettime(void); -void tm_get(p_tm tm, int *tm_block, int *tm_return); -void tm_markstart(p_tm tm); void tm_open(lua_State *L); +void tm_init(p_tm tm, int block, int total); +void tm_setblock(p_tm tm, int block); +void tm_settotal(p_tm tm, int total); +int tm_getblock(p_tm tm); +int tm_gettotal(p_tm tm); +void tm_markstart(p_tm tm); +int tm_getstart(p_tm tm); +int tm_get(p_tm tm); +int tm_gettime(void); +int tm_meth_timeout(lua_State *L, p_tm tm); #endif diff --git a/src/udp.c b/src/udp.c index 361816c..1701d1b 100644 --- a/src/udp.c +++ b/src/udp.c @@ -1,299 +1,263 @@ /*=========================================================================*\ -* UDP class: inherits from Socked and Internet domain classes and provides -* all the functionality for UDP objects. -* Lua methods: -* send: using compat module -* sendto: using compat module -* receive: using compat module -* receivefrom: using compat module -* setpeername: using internet module -* setsockname: using internet module -* Global Lua functions: -* udp: creates the udp object +* UDP object * * RCS ID: $Id$ \*=========================================================================*/ -#include +#include #include #include -#include "lsinet.h" -#include "lsudp.h" -#include "lscompat.h" -#include "lsselect.h" +#include "luasocket.h" + +#include "aux.h" +#include "inet.h" +#include "udp.h" /*=========================================================================*\ -* Internal function prototypes. +* Internal function prototypes \*=========================================================================*/ -static int udp_lua_send(lua_State *L); -static int udp_lua_sendto(lua_State *L); -static int udp_lua_receive(lua_State *L); -static int udp_lua_receivefrom(lua_State *L); -static int udp_lua_setpeername(lua_State *L); -static int udp_lua_setsockname(lua_State *L); +static int udp_global_create(lua_State *L); +static int udp_meth_send(lua_State *L); +static int udp_meth_sendto(lua_State *L); +static int udp_meth_receive(lua_State *L); +static int udp_meth_receivefrom(lua_State *L); +static int udp_meth_getsockname(lua_State *L); +static int udp_meth_getpeername(lua_State *L); +static int udp_meth_setsockname(lua_State *L); +static int udp_meth_setpeername(lua_State *L); +static int udp_meth_close(lua_State *L); +static int udp_meth_timeout(lua_State *L); -static int udp_global_udp(lua_State *L); - -static struct luaL_reg funcs[] = { - {"send", udp_lua_send}, - {"sendto", udp_lua_sendto}, - {"receive", udp_lua_receive}, - {"receivefrom", udp_lua_receivefrom}, - {"setpeername", udp_lua_setpeername}, - {"setsockname", udp_lua_setsockname}, +/* udp object methods */ +static luaL_reg udp[] = { + {"setpeername", udp_meth_setpeername}, + {"setsockname", udp_meth_setsockname}, + {"getsockname", udp_meth_getsockname}, + {"getpeername", udp_meth_getpeername}, + {"send", udp_meth_send}, + {"sendto", udp_meth_sendto}, + {"receive", udp_meth_receive}, + {"receivefrom", udp_meth_receivefrom}, + {"timeout", udp_meth_timeout}, + {"close", udp_meth_close}, + {NULL, NULL} +}; + +/* functions in library namespace */ +static luaL_reg func[] = { + {"udp", udp_global_create}, + {NULL, NULL} }; -/*=========================================================================*\ -* Exported functions -\*=========================================================================*/ /*-------------------------------------------------------------------------*\ * Initializes module \*-------------------------------------------------------------------------*/ void udp_open(lua_State *L) { - unsigned int i; - priv_newclass(L, UDP_CLASS); - udp_inherit(L, UDP_CLASS); - /* declare global functions */ - lua_pushcfunction(L, udp_global_udp); - priv_newglobal(L, "udp"); - for (i = 0; i < sizeof(funcs)/sizeof(funcs[0]); i++) - priv_newglobalmethod(L, funcs[i].name); - /* make class selectable */ - select_addclass(L, UDP_CLASS); -} - -/*-------------------------------------------------------------------------*\ -* Hook object methods to methods table. -\*-------------------------------------------------------------------------*/ -void udp_inherit(lua_State *L, cchar *lsclass) -{ - unsigned int i; - inet_inherit(L, lsclass); - for (i = 0; i < sizeof(funcs)/sizeof(funcs[0]); i++) { - lua_pushcfunction(L, funcs[i].func); - priv_setmethod(L, lsclass, funcs[i].name); - } -} - -/*-------------------------------------------------------------------------*\ -* Initializes socket structure -\*-------------------------------------------------------------------------*/ -void udp_construct(lua_State *L, p_udp udp) -{ - inet_construct(L, (p_inet) udp); - udp->udp_connected = 0; -} - -/*-------------------------------------------------------------------------*\ -* Creates a socket structure and initializes it. A socket object is -* left in the Lua stack. -* Returns -* pointer to allocated structure -\*-------------------------------------------------------------------------*/ -p_udp udp_push(lua_State *L) -{ - p_udp udp = (p_udp) lua_newuserdata(L, sizeof(t_udp)); - priv_setclass(L, UDP_CLASS); - udp_construct(L, udp); - return udp; + /* create classes */ + aux_newclass(L, "udp{connected}", udp); + aux_newclass(L, "udp{unconnected}", udp); + /* create class groups */ + aux_add2group(L, "udp{connected}", "udp{any}"); + aux_add2group(L, "udp{unconnected}", "udp{any}"); + /* define library functions */ + luaL_openlib(L, LUASOCKET_LIBNAME, func, 0); + lua_pop(L, 1); } /*=========================================================================*\ -* Socket table constructors +* Lua methods \*=========================================================================*/ /*-------------------------------------------------------------------------*\ -* Creates a udp socket object and returns it to the Lua script. -* Lua Input: [options] -* options: socket options table -* Lua Returns -* On success: udp socket -* On error: nil, followed by an error message +* Send data through connected udp socket \*-------------------------------------------------------------------------*/ -static int udp_global_udp(lua_State *L) +static int udp_meth_send(lua_State *L) { - int oldtop = lua_gettop(L); - p_udp udp = udp_push(L); - cchar *err = inet_trysocket((p_inet) udp, SOCK_DGRAM); - if (err) { - lua_pushnil(L); - lua_pushstring(L, err); - return 2; - } - if (oldtop < 1) return 1; - err = compat_trysetoptions(L, udp->fd); - if (err) { - lua_pushnil(L); - lua_pushstring(L, err); - return 2; - } - return 1; -} - -/*=========================================================================*\ -* Socket table methods -\*=========================================================================*/ -/*-------------------------------------------------------------------------*\ -* Receives data from a UDP socket -* Lua Input: sock [, wanted] -* sock: client socket created by the connect function -* wanted: the number of bytes expected (default: LUASOCKET_UDPBUFFERSIZE) -* Lua Returns -* On success: datagram received -* On error: nil, followed by an error message -\*-------------------------------------------------------------------------*/ -static int udp_lua_receive(lua_State *L) -{ - p_udp udp = (p_udp) lua_touserdata(L, 1); - char buffer[UDP_DATAGRAMSIZE]; - size_t got, wanted = (size_t) luaL_optnumber(L, 2, sizeof(buffer)); + p_udp udp = (p_udp) aux_checkclass(L, "udp{connected}", 1); + p_tm tm = &udp->tm; + size_t count, sent = 0; int err; - p_tm tm = &udp->base_tm; - wanted = MIN(wanted, sizeof(buffer)); + const char *data = luaL_checklstring(L, 2, &count); tm_markstart(tm); - err = compat_recv(udp->fd, buffer, wanted, &got, tm_getremaining(tm)); - if (err == PRIV_CLOSED) err = PRIV_REFUSED; - if (err != PRIV_DONE) lua_pushnil(L); - else lua_pushlstring(L, buffer, got); - priv_pusherror(L, err); + err = sock_send(&udp->sock, data, count, &sent, tm_get(tm)); + if (err == IO_DONE) lua_pushnumber(L, sent); + else lua_pushnil(L); + error_push(L, err); return 2; } /*-------------------------------------------------------------------------*\ -* Receives a datagram from a UDP socket -* Lua Input: sock [, wanted] -* sock: client socket created by the connect function -* wanted: the number of bytes expected (default: LUASOCKET_UDPBUFFERSIZE) -* Lua Returns -* On success: datagram received, ip and port of sender -* On error: nil, followed by an error message +* Send data through unconnected udp socket \*-------------------------------------------------------------------------*/ -static int udp_lua_receivefrom(lua_State *L) +static int udp_meth_sendto(lua_State *L) { - p_udp udp = (p_udp) lua_touserdata(L, 1); - p_tm tm = &udp->base_tm; - struct sockaddr_in peer; - size_t peer_len = sizeof(peer); - char buffer[UDP_DATAGRAMSIZE]; - size_t wanted = (size_t) luaL_optnumber(L, 2, sizeof(buffer)); - size_t got; + p_udp udp = (p_udp) aux_checkclass(L, "udp{unconnected}", 1); + size_t count, sent = 0; + const char *data = luaL_checklstring(L, 2, &count); + const char *ip = luaL_checkstring(L, 3); + ushort port = (ushort) luaL_checknumber(L, 4); + p_tm tm = &udp->tm; + struct sockaddr_in addr; int err; - if (udp->udp_connected) luaL_error(L, "receivefrom on connected socket"); + memset(&addr, 0, sizeof(addr)); + if (!inet_aton(ip, &addr.sin_addr)) + luaL_argerror(L, 3, "invalid ip address"); + addr.sin_family = AF_INET; + addr.sin_port = htons(port); tm_markstart(tm); - wanted = MIN(wanted, sizeof(buffer)); - err = compat_recvfrom(udp->fd, buffer, wanted, &got, tm_getremaining(tm), - (SA *) &peer, &peer_len); - if (err == PRIV_CLOSED) err = PRIV_REFUSED; - if (err == PRIV_DONE) { + err = sock_sendto(&udp->sock, data, count, &sent, + (SA *) &addr, sizeof(addr), tm_get(tm)); + if (err == IO_DONE) lua_pushnumber(L, sent); + else lua_pushnil(L); + error_push(L, err == IO_CLOSED ? IO_REFUSED : err); + return 2; +} + +/*-------------------------------------------------------------------------*\ +* Receives data from a UDP socket +\*-------------------------------------------------------------------------*/ +static int udp_meth_receive(lua_State *L) +{ + p_udp udp = (p_udp) aux_checkgroup(L, "udp{any}", 1); + char buffer[UDP_DATAGRAMSIZE]; + size_t got, count = (size_t) luaL_optnumber(L, 2, sizeof(buffer)); + int err; + p_tm tm = &udp->tm; + count = MIN(count, sizeof(buffer)); + tm_markstart(tm); + err = sock_recv(&udp->sock, buffer, count, &got, tm_get(tm)); + if (err == IO_DONE) lua_pushlstring(L, buffer, got); + else lua_pushnil(L); + error_push(L, err); + return 2; +} + +/*-------------------------------------------------------------------------*\ +* Receives data and sender from a UDP socket +\*-------------------------------------------------------------------------*/ +static int udp_meth_receivefrom(lua_State *L) +{ + p_udp udp = (p_udp) aux_checkclass(L, "udp{unconnected}", 1); + struct sockaddr_in addr; + size_t addr_len = sizeof(addr); + char buffer[UDP_DATAGRAMSIZE]; + size_t got, count = (size_t) luaL_optnumber(L, 2, sizeof(buffer)); + int err; + p_tm tm = &udp->tm; + tm_markstart(tm); + count = MIN(count, sizeof(buffer)); + err = sock_recvfrom(&udp->sock, buffer, count, &got, + (SA *) &addr, &addr_len, tm_get(tm)); + if (err == IO_DONE) { lua_pushlstring(L, buffer, got); - lua_pushstring(L, inet_ntoa(peer.sin_addr)); - lua_pushnumber(L, ntohs(peer.sin_port)); + lua_pushstring(L, inet_ntoa(addr.sin_addr)); + lua_pushnumber(L, ntohs(addr.sin_port)); return 3; } else { lua_pushnil(L); - priv_pusherror(L, err); + error_push(L, err); return 2; } } /*-------------------------------------------------------------------------*\ -* Send data through a connected UDP socket -* Lua Input: sock, data -* sock: udp socket -* data: data to be sent -* Lua Returns -* On success: nil, followed by the total number of bytes sent -* On error: error message +* Just call inet methods \*-------------------------------------------------------------------------*/ -static int udp_lua_send(lua_State *L) +static int udp_meth_getpeername(lua_State *L) { - p_udp udp = (p_udp) lua_touserdata(L, 1); - p_tm tm = &udp->base_tm; - size_t wanted, sent = 0; - int err; - cchar *data = luaL_checklstring(L, 2, &wanted); - if (!udp->udp_connected) luaL_error(L, "send on unconnected socket"); - tm_markstart(tm); - err = compat_send(udp->fd, data, wanted, &sent, tm_getremaining(tm)); - priv_pusherror(L, err == PRIV_CLOSED ? PRIV_REFUSED : err); - lua_pushnumber(L, sent); - return 2; + p_udp udp = (p_udp) aux_checkclass(L, "udp{connected}", 1); + return inet_meth_getpeername(L, &udp->sock); +} + +static int udp_meth_getsockname(lua_State *L) +{ + p_udp udp = (p_udp) aux_checkgroup(L, "udp{any}", 1); + return inet_meth_getsockname(L, &udp->sock); } /*-------------------------------------------------------------------------*\ -* Send data through a unconnected UDP socket -* Lua Input: sock, data, ip, port -* sock: udp socket -* data: data to be sent -* ip: ip address of target -* port: port in target -* Lua Returns -* On success: nil, followed by the total number of bytes sent -* On error: error message +* Just call tm methods \*-------------------------------------------------------------------------*/ -static int udp_lua_sendto(lua_State *L) +static int udp_meth_timeout(lua_State *L) { - p_udp udp = (p_udp) lua_touserdata(L, 1); - size_t wanted, sent = 0; - cchar *data = luaL_checklstring(L, 2, &wanted); - cchar *ip = luaL_checkstring(L, 3); - ushort port = (ushort) luaL_checknumber(L, 4); - p_tm tm = &udp->base_tm; - struct sockaddr_in peer; - int err; - if (udp->udp_connected) luaL_error(L, "sendto on connected socket"); - memset(&peer, 0, sizeof(peer)); - if (!inet_aton(ip, &peer.sin_addr)) luaL_error(L, "invalid ip address"); - peer.sin_family = AF_INET; - peer.sin_port = htons(port); - tm_markstart(tm); - err = compat_sendto(udp->fd, data, wanted, &sent, tm_getremaining(tm), - (SA *) &peer, sizeof(peer)); - priv_pusherror(L, err == PRIV_CLOSED ? PRIV_REFUSED : err); - lua_pushnumber(L, sent); - return 2; + p_udp udp = (p_udp) aux_checkgroup(L, "udp{any}", 1); + return tm_meth_timeout(L, &udp->tm); } /*-------------------------------------------------------------------------*\ -* Associates a local address to an UDP socket -* Lua Input: address, port -* address: host name or ip address to bind to -* port: port to bind to -* Lua Returns -* On success: nil -* On error: error message +* Turns a master udp object into a client object. \*-------------------------------------------------------------------------*/ -static int udp_lua_setsockname(lua_State * L) +static int udp_meth_setpeername(lua_State *L) { - p_udp udp = (p_udp) lua_touserdata(L, 1); - cchar *address = luaL_checkstring(L, 2); - ushort port = (ushort) luaL_checknumber(L, 3); - cchar *err = inet_trybind((p_inet) udp, address, port); - if (err) lua_pushstring(L, err); - else lua_pushnil(L); - return 1; -} - -/*-------------------------------------------------------------------------*\ -* Sets a peer for a UDP socket -* Lua Input: address, port -* address: remote host name -* port: remote host port -* Lua Returns -* On success: nil -* On error: error message -\*-------------------------------------------------------------------------*/ -static int udp_lua_setpeername(lua_State *L) -{ - p_udp udp = (p_udp) lua_touserdata(L, 1); - cchar *address = luaL_checkstring(L, 2); - ushort port = (ushort) luaL_checknumber(L, 3); - cchar *err = inet_tryconnect((p_inet) udp, address, port); - if (!err) { - udp->udp_connected = 1; + p_udp udp = (p_udp) aux_checkclass(L, "udp{unconnected}", 1); + const char *address = luaL_checkstring(L, 2); + int connecting = strcmp(address, "*"); + unsigned short port = connecting ? + (ushort) luaL_checknumber(L, 3) : (ushort) luaL_optnumber(L, 3, 0); + const char *err = inet_tryconnect(&udp->sock, address, port); + if (err) { lua_pushnil(L); - } else lua_pushstring(L, err); + lua_pushstring(L, err); + return 2; + } + /* change class to connected or unconnected depending on address */ + if (connecting) aux_setclass(L, "udp{connected}", 1); + else aux_setclass(L, "udp{unconnected}", 1); + lua_pushnumber(L, 1); return 1; } +/*-------------------------------------------------------------------------*\ +* Closes socket used by object +\*-------------------------------------------------------------------------*/ +static int udp_meth_close(lua_State *L) +{ + p_udp udp = (p_udp) aux_checkgroup(L, "udp{any}", 1); + sock_destroy(&udp->sock); + return 0; +} + +/*-------------------------------------------------------------------------*\ +* Turns a master object into a server object +\*-------------------------------------------------------------------------*/ +static int udp_meth_setsockname(lua_State *L) +{ + p_udp udp = (p_udp) aux_checkclass(L, "udp{master}", 1); + const char *address = luaL_checkstring(L, 2); + unsigned short port = (ushort) luaL_checknumber(L, 3); + const char *err = inet_trybind(&udp->sock, address, port, -1); + if (err) { + lua_pushnil(L); + lua_pushstring(L, err); + return 2; + } + lua_pushnumber(L, 1); + return 1; +} + +/*=========================================================================*\ +* Library functions +\*=========================================================================*/ +/*-------------------------------------------------------------------------*\ +* Creates a master udp object +\*-------------------------------------------------------------------------*/ +int udp_global_create(lua_State *L) +{ + /* allocate udp object */ + p_udp udp = (p_udp) lua_newuserdata(L, sizeof(t_udp)); + /* set its type as master object */ + aux_setclass(L, "udp{unconnected}", -1); + /* try to allocate a system socket */ + const char *err = inet_trycreate(&udp->sock, SOCK_DGRAM); + if (err) { + /* get rid of object on stack and push error */ + lua_pop(L, 1); + lua_pushnil(L); + lua_pushstring(L, err); + return 2; + } + /* initialize timeout management */ + tm_init(&udp->tm, -1, -1); + return 1; +} diff --git a/src/udp.h b/src/udp.h index 928a99f..4ba53e6 100644 --- a/src/udp.h +++ b/src/udp.h @@ -1,30 +1,19 @@ -/*=========================================================================*\ -* UDP class: inherits from Socked and Internet domain classes and provides -* all the functionality for UDP objects. -* -* RCS ID: $Id$ -\*=========================================================================*/ -#ifndef UDP_H_ -#define UDP_H_ +#ifndef UDP_H +#define UDP_H -#include "lsinet.h" +#include -#define UDP_CLASS "luasocket(UDP socket)" +#include "tm.h" +#include "sock.h" #define UDP_DATAGRAMSIZE 576 -#define UDP_FIELDS \ - INET_FIELDS; \ - int udp_connected - -typedef struct t_udp_tag { - UDP_FIELDS; +typedef struct t_udp_ { + t_sock sock; + t_tm tm; } t_udp; typedef t_udp *p_udp; -void udp_inherit(lua_State *L, cchar *lsclass); -void udp_construct(lua_State *L, p_udp udp); void udp_open(lua_State *L); -p_udp udp_push(lua_State *L); #endif diff --git a/src/unix.c b/src/usocket.c similarity index 60% rename from src/unix.c rename to src/usocket.c index 23984b0..b4b8d5a 100644 --- a/src/unix.c +++ b/src/usocket.c @@ -1,5 +1,5 @@ /*=========================================================================*\ -* Network compatibilization module: Unix version +* Socket compatibilization module for Unix * * RCS ID: $Id$ \*=========================================================================*/ @@ -7,20 +7,20 @@ #include #include -#include "lscompat.h" +#include "sock.h" /*=========================================================================*\ * Internal function prototypes \*=========================================================================*/ -static cchar *try_setoption(lua_State *L, COMPAT_FD sock); -static cchar *try_setbooloption(lua_State *L, COMPAT_FD sock, int name); +static const char *try_setoption(lua_State *L, p_sock ps); +static const char *try_setbooloption(lua_State *L, p_sock ps, int name); /*=========================================================================*\ * Exported functions. \*=========================================================================*/ -int compat_open(lua_State *L) +int sock_open(lua_State *L) { - /* Instals a handler to ignore sigpipe. */ + /* instals a handler to ignore sigpipe. */ struct sigaction new; memset(&new, 0, sizeof(new)); new.sa_handler = SIG_IGN; @@ -28,143 +28,178 @@ int compat_open(lua_State *L) return 1; } -COMPAT_FD compat_accept(COMPAT_FD s, struct sockaddr *addr, - size_t *len, int deadline) +void sock_destroy(p_sock ps) { - struct timeval tv; - fd_set fds; - tv.tv_sec = deadline / 1000; - tv.tv_usec = (deadline % 1000) * 1000; - FD_ZERO(&fds); - FD_SET(s, &fds); - select(s+1, &fds, NULL, NULL, deadline >= 0 ? &tv : NULL); - return accept(s, addr, len); + close(*ps); } -int compat_send(COMPAT_FD c, cchar *data, size_t count, size_t *sent, - int deadline) +const char *sock_create(p_sock ps, int domain, int type, int protocol) { + t_sock sock = socket(domain, type, protocol); + if (sock == SOCK_INVALID) return sock_createstrerror(); + *ps = sock; + sock_setnonblocking(ps); + sock_setreuseaddr(ps); + return NULL; +} + +const char *sock_connect(p_sock ps, SA *addr, size_t addr_len) +{ + if (connect(*ps, addr, addr_len) < 0) return sock_connectstrerror(); + else return NULL; +} + +const char *sock_bind(p_sock ps, SA *addr, size_t addr_len) +{ + if (bind(*ps, addr, addr_len) < 0) return sock_bindstrerror(); + else return NULL; +} + +void sock_listen(p_sock ps, int backlog) +{ + listen(*ps, backlog); +} + +void sock_accept(p_sock ps, p_sock pa, SA *addr, size_t *addr_len, int timeout) +{ + t_sock sock = *ps; + struct timeval tv; + fd_set fds; + tv.tv_sec = timeout / 1000; + tv.tv_usec = (timeout % 1000) * 1000; + FD_ZERO(&fds); + FD_SET(sock, &fds); + select(sock+1, &fds, NULL, NULL, timeout >= 0 ? &tv : NULL); + *pa = accept(sock, addr, addr_len); +} + +int sock_send(p_sock ps, const char *data, size_t count, size_t *sent, + int timeout) +{ + t_sock sock = *ps; struct timeval tv; fd_set fds; ssize_t put = 0; int err; int ret; - tv.tv_sec = deadline / 1000; - tv.tv_usec = (deadline % 1000) * 1000; + tv.tv_sec = timeout / 1000; + tv.tv_usec = (timeout % 1000) * 1000; FD_ZERO(&fds); - FD_SET(c, &fds); - ret = select(c+1, NULL, &fds, NULL, deadline >= 0 ? &tv : NULL); + FD_SET(sock, &fds); + ret = select(sock+1, NULL, &fds, NULL, timeout >= 0 ? &tv : NULL); if (ret > 0) { - put = write(c, data, count); + put = write(sock, data, count); if (put <= 0) { - err = PRIV_CLOSED; + err = IO_CLOSED; #ifdef __CYGWIN__ /* this is for CYGWIN, which is like Unix but has Win32 bugs */ - if (errno == EWOULDBLOCK) err = PRIV_DONE; + if (errno == EWOULDBLOCK) err = IO_DONE; #endif *sent = 0; } else { *sent = put; - err = PRIV_DONE; + err = IO_DONE; } return err; } else { *sent = 0; - return PRIV_TIMEOUT; + return IO_TIMEOUT; } } -int compat_sendto(COMPAT_FD c, cchar *data, size_t count, size_t *sent, - int deadline, SA *addr, size_t len) +int sock_sendto(p_sock ps, const char *data, size_t count, size_t *sent, + SA *addr, size_t addr_len, int timeout) { + t_sock sock = *ps; struct timeval tv; fd_set fds; ssize_t put = 0; int err; int ret; - tv.tv_sec = deadline / 1000; - tv.tv_usec = (deadline % 1000) * 1000; + tv.tv_sec = timeout / 1000; + tv.tv_usec = (timeout % 1000) * 1000; FD_ZERO(&fds); - FD_SET(c, &fds); - ret = select(c+1, NULL, &fds, NULL, deadline >= 0 ? &tv : NULL); + FD_SET(sock, &fds); + ret = select(sock+1, NULL, &fds, NULL, timeout >= 0 ? &tv : NULL); if (ret > 0) { - put = sendto(c, data, count, 0, addr, len); + put = sendto(sock, data, count, 0, addr, addr_len); if (put <= 0) { - err = PRIV_CLOSED; + err = IO_CLOSED; #ifdef __CYGWIN__ /* this is for CYGWIN, which is like Unix but has Win32 bugs */ - if (sent < 0 && errno == EWOULDBLOCK) err = PRIV_DONE; + if (sent < 0 && errno == EWOULDBLOCK) err = IO_DONE; #endif *sent = 0; } else { *sent = put; - err = PRIV_DONE; + err = IO_DONE; } return err; } else { *sent = 0; - return PRIV_TIMEOUT; + return IO_TIMEOUT; } } -int compat_recv(COMPAT_FD c, char *data, size_t count, size_t *got, - int deadline) +int sock_recv(p_sock ps, char *data, size_t count, size_t *got, int timeout) { + t_sock sock = *ps; struct timeval tv; fd_set fds; int ret; ssize_t taken = 0; - tv.tv_sec = deadline / 1000; - tv.tv_usec = (deadline % 1000) * 1000; + tv.tv_sec = timeout / 1000; + tv.tv_usec = (timeout % 1000) * 1000; FD_ZERO(&fds); - FD_SET(c, &fds); - ret = select(c+1, &fds, NULL, NULL, deadline >= 0 ? &tv : NULL); + FD_SET(sock, &fds); + ret = select(sock+1, &fds, NULL, NULL, timeout >= 0 ? &tv : NULL); if (ret > 0) { - taken = read(c, data, count); + taken = read(sock, data, count); if (taken <= 0) { *got = 0; - return PRIV_CLOSED; + return IO_CLOSED; } else { *got = taken; - return PRIV_DONE; + return IO_DONE; } } else { *got = 0; - return PRIV_TIMEOUT; + return IO_TIMEOUT; } } -int compat_recvfrom(COMPAT_FD c, char *data, size_t count, size_t *got, - int deadline, SA *addr, size_t *len) +int sock_recvfrom(p_sock ps, char *data, size_t count, size_t *got, + SA *addr, size_t *addr_len, int timeout) { + t_sock sock = *ps; struct timeval tv; fd_set fds; int ret; ssize_t taken = 0; - tv.tv_sec = deadline / 1000; - tv.tv_usec = (deadline % 1000) * 1000; + tv.tv_sec = timeout / 1000; + tv.tv_usec = (timeout % 1000) * 1000; FD_ZERO(&fds); - FD_SET(c, &fds); - ret = select(c+1, &fds, NULL, NULL, deadline >= 0 ? &tv : NULL); + FD_SET(sock, &fds); + ret = select(sock+1, &fds, NULL, NULL, timeout >= 0 ? &tv : NULL); if (ret > 0) { - taken = recvfrom(c, data, count, 0, addr, len); + taken = recvfrom(sock, data, count, 0, addr, addr_len); if (taken <= 0) { *got = 0; - return PRIV_CLOSED; + return IO_CLOSED; } else { *got = taken; - return PRIV_DONE; + return IO_DONE; } } else { *got = 0; - return PRIV_TIMEOUT; + return IO_TIMEOUT; } } /*-------------------------------------------------------------------------*\ * Returns a string describing the last host manipulation error. \*-------------------------------------------------------------------------*/ -const char *compat_hoststrerror(void) +const char *sock_hoststrerror(void) { switch (h_errno) { case HOST_NOT_FOUND: return "host not found"; @@ -178,7 +213,7 @@ const char *compat_hoststrerror(void) /*-------------------------------------------------------------------------*\ * Returns a string describing the last socket manipulation error. \*-------------------------------------------------------------------------*/ -const char *compat_socketstrerror(void) +const char *sock_createstrerror(void) { switch (errno) { case EACCES: return "access denied"; @@ -192,7 +227,7 @@ const char *compat_socketstrerror(void) /*-------------------------------------------------------------------------*\ * Returns a string describing the last bind command error. \*-------------------------------------------------------------------------*/ -const char *compat_bindstrerror(void) +const char *sock_bindstrerror(void) { switch (errno) { case EBADF: return "invalid descriptor"; @@ -209,7 +244,7 @@ const char *compat_bindstrerror(void) /*-------------------------------------------------------------------------*\ * Returns a string describing the last connect error. \*-------------------------------------------------------------------------*/ -const char *compat_connectstrerror(void) +const char *sock_connectstrerror(void) { switch (errno) { case EBADF: return "invalid descriptor"; @@ -229,40 +264,30 @@ const char *compat_connectstrerror(void) * Input * sock: socket descriptor \*-------------------------------------------------------------------------*/ -void compat_setreuseaddr(COMPAT_FD sock) +void sock_setreuseaddr(p_sock ps) { int val = 1; - setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (char *)&val, sizeof(val)); -} - -COMPAT_FD compat_socket(int domain, int type, int protocol) -{ - COMPAT_FD sock = socket(domain, type, protocol); - if (sock != COMPAT_INVALIDFD) { - compat_setnonblocking(sock); - compat_setreuseaddr(sock); - } - return sock; + setsockopt(*ps, SOL_SOCKET, SO_REUSEADDR, (char *)&val, sizeof(val)); } /*-------------------------------------------------------------------------*\ * Put socket into blocking mode. \*-------------------------------------------------------------------------*/ -void compat_setblocking(COMPAT_FD sock) +void sock_setblocking(p_sock ps) { - int flags = fcntl(sock, F_GETFL, 0); + int flags = fcntl(*ps, F_GETFL, 0); flags &= (~(O_NONBLOCK)); - fcntl(sock, F_SETFL, flags); + fcntl(*ps, F_SETFL, flags); } /*-------------------------------------------------------------------------*\ * Put socket into non-blocking mode. \*-------------------------------------------------------------------------*/ -void compat_setnonblocking(COMPAT_FD sock) +void sock_setnonblocking(p_sock ps) { - int flags = fcntl(sock, F_GETFL, 0); + int flags = fcntl(*ps, F_GETFL, 0); flags |= O_NONBLOCK; - fcntl(sock, F_SETFL, flags); + fcntl(*ps, F_SETFL, flags); } /*-------------------------------------------------------------------------*\ @@ -273,54 +298,50 @@ void compat_setnonblocking(COMPAT_FD sock) * Returns * NULL if successfull, error message on error \*-------------------------------------------------------------------------*/ -cchar *compat_trysetoptions(lua_State *L, COMPAT_FD sock) +const char *sock_trysetoptions(lua_State *L, p_sock ps) { if (!lua_istable(L, 1)) luaL_argerror(L, 1, "invalid options table"); lua_pushnil(L); while (lua_next(L, 1)) { - cchar *err = try_setoption(L, sock); + const char *err = try_setoption(L, ps); lua_pop(L, 1); if (err) return err; } return NULL; } -/*=========================================================================*\ -* Internal functions. -\*=========================================================================*/ -static cchar *try_setbooloption(lua_State *L, COMPAT_FD sock, int name) -{ - int bool, res; - if (!lua_isnumber(L, -1)) luaL_error(L, "invalid option value"); - bool = (int) lua_tonumber(L, -1); - res = setsockopt(sock, SOL_SOCKET, name, (char *) &bool, sizeof(bool)); - if (res < 0) return "error setting option"; - else return NULL; -} - - /*-------------------------------------------------------------------------*\ * Set socket options from a table on top of Lua stack. -* Supports SO_KEEPALIVE, SO_DONTROUTE, SO_BROADCAST, and SO_LINGER options. +* Supports SO_KEEPALIVE, SO_DONTROUTE, and SO_BROADCAST options. * Input -* L: Lua state to use -* sock: socket descriptor +* sock: socket * Returns * 1 if successful, 0 otherwise \*-------------------------------------------------------------------------*/ -static cchar *try_setoption(lua_State *L, COMPAT_FD sock) +static const char *try_setoption(lua_State *L, p_sock ps) { - static cchar *options[] = { - "SO_KEEPALIVE", "SO_DONTROUTE", "SO_BROADCAST", "SO_LINGER", NULL + static const char *options[] = { + "SO_KEEPALIVE", "SO_DONTROUTE", "SO_BROADCAST", NULL }; - cchar *option = lua_tostring(L, -2); + const char *option = lua_tostring(L, -2); if (!lua_isstring(L, -2)) return "invalid option"; switch (luaL_findstring(option, options)) { - case 0: return try_setbooloption(L, sock, SO_KEEPALIVE); - case 1: return try_setbooloption(L, sock, SO_DONTROUTE); - case 2: return try_setbooloption(L, sock, SO_BROADCAST); - case 3: return "SO_LINGER is deprecated"; + case 0: return try_setbooloption(L, ps, SO_KEEPALIVE); + case 1: return try_setbooloption(L, ps, SO_DONTROUTE); + case 2: return try_setbooloption(L, ps, SO_BROADCAST); default: return "unsupported option"; } } +/*=========================================================================*\ +* Internal functions. +\*=========================================================================*/ +static const char *try_setbooloption(lua_State *L, p_sock ps, int name) +{ + int bool, res; + if (!lua_isnumber(L, -1)) luaL_error(L, "invalid option value"); + bool = (int) lua_tonumber(L, -1); + res = setsockopt(*ps, SOL_SOCKET, name, (char *) &bool, sizeof(bool)); + if (res < 0) return "error setting option"; + else return NULL; +} diff --git a/src/unix.h b/src/usocket.h similarity index 74% rename from src/unix.h rename to src/usocket.h index 863e478..f124bce 100644 --- a/src/unix.h +++ b/src/usocket.h @@ -1,10 +1,10 @@ /*=========================================================================*\ -* Network compatibilization module: Unix version +* Socket compatibilization module for Unix * * RCS ID: $Id$ \*=========================================================================*/ -#ifndef UNIX_H_ -#define UNIX_H_ +#ifndef UNIX_H +#define UNIX_H /*=========================================================================*\ * BSD include files @@ -31,13 +31,9 @@ #include #include -#define COMPAT_FD int -#define COMPAT_INVALIDFD (-1) +typedef int t_sock; +typedef t_sock *p_sock; -#define compat_bind bind -#define compat_connect connect -#define compat_listen listen -#define compat_close close -#define compat_select select +#define SOCK_INVALID (-1) -#endif /* UNIX_H_ */ +#endif /* UNIX_H */ diff --git a/test/ftptest.lua b/test/ftptest.lua index ee3af91..6ba61a4 100644 --- a/test/ftptest.lua +++ b/test/ftptest.lua @@ -1,5 +1,3 @@ -dofile("noglobals.lua") - local similar = function(s1, s2) return string.lower(string.gsub(s1, "%s", "")) == @@ -34,7 +32,7 @@ end local index, err, saved, back, expected -local t = socket._time() +local t = socket.time() index = readfile("test/index.html") @@ -112,4 +110,4 @@ back, err = socket.ftp.get("ftp://localhost/index.wrong.html;type=a") check(err, err) print("passed all tests") -print(string.format("done in %.2fs", socket._time() - t)) +print(string.format("done in %.2fs", socket.time() - t)) diff --git a/test/httptest.lua b/test/httptest.lua index 1eb4b6a..030974c 100644 --- a/test/httptest.lua +++ b/test/httptest.lua @@ -3,9 +3,6 @@ -- needs ScriptAlias from /home/c/diego/tec/luasocket/test/cgi -- to /luasocket-test-cgi -- needs AllowOverride AuthConfig on /home/c/diego/tec/luasocket/test/auth - -dofile("noglobals.lua") - local similar = function(s1, s2) return string.lower(string.gsub(s1 or "", "%s", "")) == string.lower(string.gsub(s2 or "", "%s", "")) @@ -27,27 +24,27 @@ end local check = function (v, e) if v then print("ok") - else %fail(e) end + else fail(e) end end local check_request = function(request, expect, ignore) local response = socket.http.request(request) for i,v in response do if not ignore[i] then - if v ~= expect[i] then %fail(i .. " differs!") end + if v ~= expect[i] then fail(i .. " differs!") end end end for i,v in expect do if not ignore[i] then - if v ~= response[i] then %fail(i .. " differs!") end + if v ~= response[i] then fail(i .. " differs!") end end end print("ok") end -local request, response, ignore, expect, index, prefix, cgiprefix +local host, request, response, ignore, expect, index, prefix, cgiprefix -local t = socket._time() +local t = socket.time() host = host or "localhost" prefix = prefix or "/luasocket" @@ -310,4 +307,4 @@ check(response and response.headers) print("passed all tests") -print(string.format("done in %.2fs", socket._time() - t)) +print(string.format("done in %.2fs", socket.time() - t)) diff --git a/test/smtptest.lua b/test/smtptest.lua index 27ba400..09bf634 100644 --- a/test/smtptest.lua +++ b/test/smtptest.lua @@ -11,7 +11,7 @@ local files = { "/var/spool/mail/luasock3", } -local t = socket._time() +local t = socket.time() local err dofile("mbox.lua") @@ -106,7 +106,7 @@ local insert = function(sent, message) end local mark = function() - local time = socket._time() + local time = socket.time() return { time = time } end @@ -116,11 +116,11 @@ local wait = function(sentinel, n) while 1 do local mbox = parse(get()) if n == table.getn(mbox) then break end - if socket._time() - sentinel.time > 50 then + if socket.time() - sentinel.time > 50 then to = 1 break end - socket._sleep(1) + socket.sleep(1) io.write(".") io.stdout:flush() end @@ -256,4 +256,4 @@ for i = 1, table.getn(mbox) do end print("passed all tests") -print(string.format("done in %.2fs", socket._time() - t)) +print(string.format("done in %.2fs", socket.time() - t)) diff --git a/test/testclnt.lua b/test/testclnt.lua index 3e80a36..b2b4b18 100644 --- a/test/testclnt.lua +++ b/test/testclnt.lua @@ -43,7 +43,7 @@ function check_timeout(tm, sl, elapsed, err, opp, mode, alldone) else pass("proper timeout") end end else - if mode == "return" then + if mode == "total" then if elapsed > tm then if err ~= "timeout" then fail("should have timed out") else pass("proper timeout") end @@ -66,17 +66,17 @@ function check_timeout(tm, sl, elapsed, err, opp, mode, alldone) end end +if not socket.debug then + fail("Please define LUASOCKET_DEBUG and recompile LuaSocket") +end + io.write("----------------------------------------------\n", "LuaSocket Test Procedures\n", "----------------------------------------------\n") -if not socket._time or not socket._sleep then - fail("not compiled with _DEBUG") -end +start = socket.time() -start = socket._time() - -function tcpreconnect() +function reconnect() io.write("attempting data connection... ") if data then data:close() end remote [[ @@ -87,109 +87,85 @@ function tcpreconnect() if not data then fail(err) else pass("connected!") end end -reconnect = tcpreconnect pass("attempting control connection...") control, err = socket.connect(host, port) if err then fail(err) else pass("connected!") end ------------------------------------------------------------------------- -test("bugs") - -io.write("empty host connect: ") -function empty_connect() - if data then data:close() data = nil end - remote [[ - if data then data:close() data = nil end - data = server:accept() - ]] - data, err = socket.connect("", port) - if not data then - pass("ok") - data = socket.connect(host, port) - else fail("should not have connected!") end -end - -empty_connect() - -io.write("active close: ") -function active_close() - reconnect() - if socket._isclosed(data) then fail("should not be closed") end - data:close() - if not socket._isclosed(data) then fail("should be closed") end - data = nil - local udp = socket.udp() - if socket._isclosed(udp) then fail("should not be closed") end - udp:close() - if not socket._isclosed(udp) then fail("should be closed") end - pass("ok") -end - -active_close() - ------------------------------------------------------------------------ test("method registration") function test_methods(sock, methods) for _, v in methods do if type(sock[v]) ~= "function" then - fail(type(sock) .. " method " .. v .. "not registered") + fail(sock.class .. " method '" .. v .. "' not registered") end end - pass(type(sock) .. " methods are ok") + pass(sock.class .. " methods are ok") end -test_methods(control, { - "close", - "timeout", - "send", - "receive", +test_methods(socket.tcp(), { + "connect", + "send", + "receive", + "bind", + "accept", + "setpeername", + "setsockname", "getpeername", - "getsockname" + "getsockname", + "timeout", + "close", }) -if udpsocket then - test_methods(socket.udp(), { - "close", - "timeout", - "send", - "sendto", - "receive", - "receivefrom", - "getpeername", - "getsockname", - "setsockname", - "setpeername" - }) -end - -test_methods(socket.bind("*", 0), { - "close", +test_methods(socket.udp(), { + "getpeername", + "getsockname", + "setsockname", + "setpeername", + "send", + "sendto", + "receive", + "receivefrom", "timeout", - "accept" + "close", }) ------------------------------------------------------------------------ -test("select function") -function test_selectbugs() - local r, s, e = socket.select(nil, nil, 0.1) - assert(type(r) == "table" and type(s) == "table" and e == "timeout") - pass("both nil: ok") - local udp = socket.udp() - udp:close() - r, s, e = socket.select({ udp }, { udp }, 0.1) - assert(type(r) == "table" and type(s) == "table" and e == "timeout") - pass("closed sockets: ok") - e = pcall(socket.select, "wrong", 1, 0.1) - assert(e == false) - e = pcall(socket.select, {}, 1, 0.1) - assert(e == false) - pass("invalid input: ok") +test("mixed patterns") + +function test_mixed(len) + reconnect() + local inter = math.ceil(len/4) + local p1 = "unix " .. string.rep("x", inter) .. "line\n" + local p2 = "dos " .. string.rep("y", inter) .. "line\r\n" + local p3 = "raw " .. string.rep("z", inter) .. "bytes" + local p4 = "end" .. string.rep("w", inter) .. "bytes" + local bp1, bp2, bp3, bp4 + pass(len .. " byte(s) patterns") +remote (string.format("str = data:receive(%d)", + string.len(p1)+string.len(p2)+string.len(p3)+string.len(p4))) + sent, err = data:send(p1, p2, p3, p4) + if err then fail(err) end +remote "data:send(str); data:close()" + bp1, bp2, bp3, bp4, err = data:receive("*l", "*l", string.len(p3), "*a") + if err then fail(err) end + if bp1.."\n" == p1 and bp2.."\r\n" == p2 and bp3 == p3 and bp4 == p4 then + pass("patterns match") + else fail("patterns don't match") end end -test_selectbugs() + +test_mixed(1) +test_mixed(17) +test_mixed(200) +test_mixed(4091) +test_mixed(80199) +test_mixed(4091) +test_mixed(200) +test_mixed(17) +test_mixed(1) ------------------------------------------------------------------------ test("character line") @@ -202,7 +178,7 @@ function test_asciiline(len) str = str .. str10 pass(len .. " byte(s) line") remote "str = data:receive()" - err = data:send(str, "\n") + sent, err = data:send(str, "\n") if err then fail(err) end remote "data:send(str, '\\n')" back, err = data:receive() @@ -230,7 +206,7 @@ function test_rawline(len) str = str .. str10 pass(len .. " byte(s) line") remote "str = data:receive()" - err = data:send(str, "\n") + sent, err = data:send(str, "\n") if err then fail(err) end remote "data:send(str, '\\n')" back, err = data:receive() @@ -262,9 +238,9 @@ function test_raw(len) s2 = string.rep("y", len-half) pass(len .. " byte(s) block") remote (string.format("str = data:receive(%d)", len)) - err = data:send(s1) + sent, err = data:send(s1) if err then fail(err) end - err = data:send(s2) + sent, err = data:send(s2) if err then fail(err) end remote "data:send(str)" back, err = data:receive(len) @@ -304,39 +280,139 @@ test_raw(17) test_raw(1) ------------------------------------------------------------------------ -test("mixed patterns") -reconnect() +test("total timeout on receive") +function test_totaltimeoutreceive(len, tm, sl) + local str, err, total + reconnect() + pass("%d bytes, %ds total timeout, %ds pause", len, tm, sl) + remote (string.format ([[ + data:timeout(%d) + str = string.rep('a', %d) + data:send(str) + print('server: sleeping for %ds') + socket.sleep(%d) + print('server: woke up') + data:send(str) + ]], 2*tm, len, sl, sl)) + data:timeout(tm, "total") + str, err, elapsed = data:receive(2*len) + check_timeout(tm, sl, elapsed, err, "receive", "total", + string.len(str) == 2*len) +end +test_totaltimeoutreceive(800091, 1, 3) +test_totaltimeoutreceive(800091, 2, 3) +test_totaltimeoutreceive(800091, 3, 2) +test_totaltimeoutreceive(800091, 3, 1) -function test_mixed(len) - local inter = math.floor(len/3) - local p1 = "unix " .. string.rep("x", inter) .. "line\n" - local p2 = "dos " .. string.rep("y", inter) .. "line\r\n" - local p3 = "raw " .. string.rep("z", inter) .. "bytes" - local bp1, bp2, bp3 - pass(len .. " byte(s) patterns") -remote (string.format("str = data:receive(%d)", - string.len(p1)+string.len(p2)+string.len(p3))) - err = data:send(p1, p2, p3) - if err then fail(err) end -remote "data:send(str)" - bp1, bp2, bp3, err = data:receive("*lu", "*l", string.len(p3)) - if err then fail(err) end - if bp1.."\n" == p1 and bp2.."\r\n" == p2 and bp3 == p3 then - pass("patterns match") - else fail("patterns don't match") end +------------------------------------------------------------------------ +test("total timeout on send") +function test_totaltimeoutsend(len, tm, sl) + local str, err, total + reconnect() + pass("%d bytes, %ds total timeout, %ds pause", len, tm, sl) + remote (string.format ([[ + data:timeout(%d) + str = data:receive(%d) + print('server: sleeping for %ds') + socket.sleep(%d) + print('server: woke up') + str = data:receive(%d) + ]], 2*tm, len, sl, sl, len)) + data:timeout(tm, "total") + str = string.rep("a", 2*len) + total, err, elapsed = data:send(str) + check_timeout(tm, sl, elapsed, err, "send", "total", + total == 2*len) +end +test_totaltimeoutsend(800091, 1, 3) +test_totaltimeoutsend(800091, 2, 3) +test_totaltimeoutsend(800091, 3, 2) +test_totaltimeoutsend(800091, 3, 1) + +------------------------------------------------------------------------ +test("blocking timeout on receive") +function test_blockingtimeoutreceive(len, tm, sl) + local str, err, total + reconnect() + pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl) + remote (string.format ([[ + data:timeout(%d) + str = string.rep('a', %d) + data:send(str) + print('server: sleeping for %ds') + socket.sleep(%d) + print('server: woke up') + data:send(str) + ]], 2*tm, len, sl, sl)) + data:timeout(tm) + str, err, elapsed = data:receive(2*len) + check_timeout(tm, sl, elapsed, err, "receive", "blocking", + string.len(str) == 2*len) +end +test_blockingtimeoutreceive(800091, 1, 3) +test_blockingtimeoutreceive(800091, 2, 3) +test_blockingtimeoutreceive(800091, 3, 2) +test_blockingtimeoutreceive(800091, 3, 1) + +------------------------------------------------------------------------ +test("blocking timeout on send") +function test_blockingtimeoutsend(len, tm, sl) + local str, err, total + reconnect() + pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl) + remote (string.format ([[ + data:timeout(%d) + str = data:receive(%d) + print('server: sleeping for %ds') + socket.sleep(%d) + print('server: woke up') + str = data:receive(%d) + ]], 2*tm, len, sl, sl, len)) + data:timeout(tm) + str = string.rep("a", 2*len) + total, err, elapsed = data:send(str) + check_timeout(tm, sl, elapsed, err, "send", "blocking", + total == 2*len) +end +test_blockingtimeoutsend(800091, 1, 3) +test_blockingtimeoutsend(800091, 2, 3) +test_blockingtimeoutsend(800091, 3, 2) +test_blockingtimeoutsend(800091, 3, 1) + +------------------------------------------------------------------------ +test("bugs") + +io.write("empty host connect: ") +function empty_connect() + if data then data:close() data = nil end + remote [[ + if data then data:close() data = nil end + data = server:accept() + ]] + data, err = socket.connect("", port) + if not data then + pass("ok") + data = socket.connect(host, port) + else fail("should not have connected!") end end -test_mixed(1) -test_mixed(17) -test_mixed(200) -test_mixed(4091) -test_mixed(80199) -test_mixed(800000) -test_mixed(80199) -test_mixed(4091) -test_mixed(200) -test_mixed(17) -test_mixed(1) +empty_connect() + +-- io.write("active close: ") +function active_close() + reconnect() + if socket._isclosed(data) then fail("should not be closed") end + data:close() + if not socket._isclosed(data) then fail("should be closed") end + data = nil + local udp = socket.udp() + if socket._isclosed(udp) then fail("should not be closed") end + udp:close() + if not socket._isclosed(udp) then fail("should be closed") end + pass("ok") +end + +-- active_close() ------------------------------------------------------------------------ test("closed connection detection") @@ -363,7 +439,7 @@ function test_closed() data:close() data = nil ]] - err, total = data:send(string.rep("ugauga", 100000)) + total, err = data:send(string.rep("ugauga", 100000)) if not err then pass("failed: output buffer is at least %d bytes long!", total) elseif err ~= "closed" then @@ -376,106 +452,26 @@ end test_closed() ------------------------------------------------------------------------ -test("return timeout on receive") -function test_blockingtimeoutreceive(len, tm, sl) - local str, err, total - reconnect() - pass("%d bytes, %ds return timeout, %ds pause", len, tm, sl) - remote (string.format ([[ - data:timeout(%d) - str = string.rep('a', %d) - data:send(str) - print('server: sleeping for %ds') - socket._sleep(%d) - print('server: woke up') - data:send(str) - ]], 2*tm, len, sl, sl)) - data:timeout(tm, "return") - str, err, elapsed = data:receive(2*len) - check_timeout(tm, sl, elapsed, err, "receive", "return", - string.len(str) == 2*len) +test("select function") +function test_selectbugs() + local r, s, e = socket.select(nil, nil, 0.1) + assert(type(r) == "table" and type(s) == "table" and e == "timeout") + pass("both nil: ok") + local udp = socket.udp() + udp:close() + r, s, e = socket.select({ udp }, { udp }, 0.1) + assert(type(r) == "table" and type(s) == "table" and e == "timeout") + pass("closed sockets: ok") + e = pcall(socket.select, "wrong", 1, 0.1) + assert(e == false) + e = pcall(socket.select, {}, 1, 0.1) + assert(e == false) + pass("invalid input: ok") end -test_blockingtimeoutreceive(800091, 1, 3) -test_blockingtimeoutreceive(800091, 2, 3) -test_blockingtimeoutreceive(800091, 3, 2) -test_blockingtimeoutreceive(800091, 3, 1) ------------------------------------------------------------------------- -test("return timeout on send") -function test_returntimeoutsend(len, tm, sl) - local str, err, total - reconnect() - pass("%d bytes, %ds return timeout, %ds pause", len, tm, sl) - remote (string.format ([[ - data:timeout(%d) - str = data:receive(%d) - print('server: sleeping for %ds') - socket._sleep(%d) - print('server: woke up') - str = data:receive(%d) - ]], 2*tm, len, sl, sl, len)) - data:timeout(tm, "return") - str = string.rep("a", 2*len) - err, total, elapsed = data:send(str) - check_timeout(tm, sl, elapsed, err, "send", "return", - total == 2*len) -end -test_returntimeoutsend(800091, 1, 3) -test_returntimeoutsend(800091, 2, 3) -test_returntimeoutsend(800091, 3, 2) -test_returntimeoutsend(800091, 3, 1) +-- test_selectbugs() ------------------------------------------------------------------------- -test("blocking timeout on receive") -function test_blockingtimeoutreceive(len, tm, sl) - local str, err, total - reconnect() - pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl) - remote (string.format ([[ - data:timeout(%d) - str = string.rep('a', %d) - data:send(str) - print('server: sleeping for %ds') - socket._sleep(%d) - print('server: woke up') - data:send(str) - ]], 2*tm, len, sl, sl)) - data:timeout(tm) - str, err, elapsed = data:receive(2*len) - check_timeout(tm, sl, elapsed, err, "receive", "blocking", - string.len(str) == 2*len) -end -test_blockingtimeoutreceive(800091, 1, 3) -test_blockingtimeoutreceive(800091, 2, 3) -test_blockingtimeoutreceive(800091, 3, 2) -test_blockingtimeoutreceive(800091, 3, 1) ------------------------------------------------------------------------- -test("blocking timeout on send") -function test_blockingtimeoutsend(len, tm, sl) - local str, err, total - reconnect() - pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl) - remote (string.format ([[ - data:timeout(%d) - str = data:receive(%d) - print('server: sleeping for %ds') - socket._sleep(%d) - print('server: woke up') - str = data:receive(%d) - ]], 2*tm, len, sl, sl, len)) - data:timeout(tm) - str = string.rep("a", 2*len) - err, total, elapsed = data:send(str) - check_timeout(tm, sl, elapsed, err, "send", "blocking", - total == 2*len) -end -test_blockingtimeoutsend(800091, 1, 3) -test_blockingtimeoutsend(800091, 2, 3) -test_blockingtimeoutsend(800091, 3, 2) -test_blockingtimeoutsend(800091, 3, 1) - ------------------------------------------------------------------------- -test(string.format("done in %.2fs", socket._time() - start)) +test(string.format("done in %.2fs", socket.time() - start)) diff --git a/test/testsrvr.lua b/test/testsrvr.lua index fb77ea5..3c40840 100644 --- a/test/testsrvr.lua +++ b/test/testsrvr.lua @@ -13,12 +13,13 @@ while 1 do print("server: closing connection...") break end - error = control:send("\n") + sent, error = control:send("\n") if error then control:close() print("server: closing connection...") break end + print(command); (loadstring(command))() end end diff --git a/test/tftptest.lua b/test/tftptest.lua index a435ad4..a478ed8 100644 --- a/test/tftptest.lua +++ b/test/tftptest.lua @@ -1,5 +1,5 @@ -- load tftpclnt.lua -dofile("tftpclnt.lua") +dofile("tftp.lua") -- needs tftp server running on localhost, with root pointing to -- a directory with index.html in it @@ -13,11 +13,8 @@ function readfile(file) end host = host or "localhost" -print("downloading") -err = tftp_get(host, 69, "index.html", "index.got") +retrieved, err = socket.tftp.get("tftp://" .. host .."/index.html") assert(not err, err) original = readfile("test/index.html") -retrieved = readfile("index.got") -os.remove("index.got") assert(original == retrieved, "files differ!") print("passed")