Porting to LUA 5.0 final

This commit is contained in:
Diego Nehab 2003-05-25 01:54:13 +00:00
parent c1ef3e7103
commit 0f6c8d50a9
32 changed files with 1539 additions and 1128 deletions

25
NEW
View File

@ -1,5 +1,20 @@
Socket structures are independent All functions provided by the library are in the namespace "socket".
UDPBUFFERSIZE is now internal Functions such as send/receive/timeout/close etc do not exist in the
Better treatment of closed connections: test!!! namespace. They are now only available as methods of the appropriate
HTTP post now deals with 1xx codes objects.
connect, bind etc only try first address returned by resolver
Object has been changed to become more uniform. First create an object for
a given domain/family and protocol. Then connect or bind if needed. Then
use IO functions.
All functions return a non-nil value as first return value if successful.
All functions return nil followed by error message in case of error.
WARNING: The send function was affected.
Better error messages and parameter checking.
UDP connected udp sockets can break association with peer by calling
setpeername with address "*".
socket.sleep and socket.time are now part of the library and are
supported.

131
etc/tftp.lua Normal file
View File

@ -0,0 +1,131 @@
-----------------------------------------------------------------------------
-- TFTP support for the Lua language
-- LuaSocket 1.5 toolkit.
-- Author: Diego Nehab
-- Conforming to: RFC 783, LTN7
-- RCS ID: $Id$
-----------------------------------------------------------------------------
local Public, Private = {}, {}
local socket = _G[LUASOCKET_LIBNAME] -- get LuaSocket namespace
socket.tftp = Public -- create tftp sub namespace
-----------------------------------------------------------------------------
-- Program constants
-----------------------------------------------------------------------------
local char = string.char
local byte = string.byte
Public.PORT = 69
Private.OP_RRQ = 1
Private.OP_WRQ = 2
Private.OP_DATA = 3
Private.OP_ACK = 4
Private.OP_ERROR = 5
Private.OP_INV = {"RRQ", "WRQ", "DATA", "ACK", "ERROR"}
-----------------------------------------------------------------------------
-- Packet creation functions
-----------------------------------------------------------------------------
function Private.RRQ(source, mode)
return char(0, Private.OP_RRQ) .. source .. char(0) .. mode .. char(0)
end
function Private.WRQ(source, mode)
return char(0, Private.OP_RRQ) .. source .. char(0) .. mode .. char(0)
end
function Private.ACK(block)
local low, high
low = math.mod(block, 256)
high = (block - low)/256
return char(0, Private.OP_ACK, high, low)
end
function Private.get_OP(dgram)
local op = byte(dgram, 1)*256 + byte(dgram, 2)
return op
end
-----------------------------------------------------------------------------
-- Packet analysis functions
-----------------------------------------------------------------------------
function Private.split_DATA(dgram)
local block = byte(dgram, 3)*256 + byte(dgram, 4)
local data = string.sub(dgram, 5)
return block, data
end
function Private.get_ERROR(dgram)
local code = byte(dgram, 3)*256 + byte(dgram, 4)
local msg
_,_, msg = string.find(dgram, "(.*)\000", 5)
return string.format("error code %d: %s", code, msg)
end
-----------------------------------------------------------------------------
-- Downloads and returns a file pointed to by url
-----------------------------------------------------------------------------
function Public.get(url)
local parsed = socket.url.parse(url, {
host = "",
port = Public.PORT,
path ="/",
scheme = "tftp"
})
if parsed.scheme ~= "tftp" then
return nil, string.format("unknown scheme '%s'", parsed.scheme)
end
local retries, dgram, sent, datahost, dataport, code
local cat = socket.concat.create()
local last = 0
local udp, err = socket.udp()
if not udp then return nil, err end
-- convert from name to ip if needed
parsed.host = socket.toip(parsed.host)
udp:timeout(1)
-- first packet gives data host/port to be used for data transfers
retries = 0
repeat
sent, err = udp:sendto(Private.RRQ(parsed.path, "octet"),
parsed.host, parsed.port)
if err then return nil, err end
dgram, datahost, dataport = udp:receivefrom()
retries = retries + 1
until dgram or datahost ~= "timeout" or retries > 5
if not dgram then return nil, datahost end
-- associate socket with data host/port
udp:setpeername(datahost, dataport)
-- process all data packets
while 1 do
-- decode packet
code = Private.get_OP(dgram)
if code == Private.OP_ERROR then
return nil, Private.get_ERROR(dgram)
end
if code ~= Private.OP_DATA then
return nil, "unhandled opcode " .. code
end
-- get data packet parts
local block, data = Private.split_DATA(dgram)
-- if not repeated, write
if block == last+1 then
cat:addstring(data)
last = block
end
-- last packet brings less than 512 bytes of data
if string.len(data) < 512 then
sent, err = udp:send(Private.ACK(block))
return cat:getresult()
end
-- get the next packet
retries = 0
repeat
sent, err = udp:send(Private.ACK(last))
if err then return err end
dgram, err = udp:receive()
retries = retries + 1
until dgram or err ~= "timeout" or retries > 5
if not dgram then return err end
end
end

View File

@ -7,7 +7,8 @@ end
host = socket.toip(host) host = socket.toip(host)
udp = socket.udp() udp = socket.udp()
print("Using host '" ..host.. "' and port " ..port.. "...") print("Using host '" ..host.. "' and port " ..port.. "...")
err = udp:sendto("anything", host, port) udp:setpeername(host, port)
sent, err = udp:send("anything")
if err then print(err) exit() end if err then print(err) exit() end
dgram, err = udp:receive() dgram, err = udp:receive()
if not dgram then print(err) exit() end if not dgram then print(err) exit() end

View File

@ -5,18 +5,18 @@ if arg then
port = arg[2] or port port = arg[2] or port
end end
print("Attempting connection to host '" ..host.. "' and port " ..port.. "...") print("Attempting connection to host '" ..host.. "' and port " ..port.. "...")
c, e = connect(host, port) c, e = socket.connect(host, port)
if not c then if not c then
print(e) print(e)
exit() os.exit()
end end
print("Connected! Please type stuff (empty line to stop):") print("Connected! Please type stuff (empty line to stop):")
l = read() l = io.read()
while l and l ~= "" and not e do while l and l ~= "" and not e do
e = c:send(l, "\n") t, e = c:send(l, "\n")
if e then if e then
print(e) print(e)
exit() os.exit()
end end
l = read() l = io.read()
end end

113
src/auxiliar.c Normal file
View File

@ -0,0 +1,113 @@
/*=========================================================================*\
* Auxiliar routines for class hierarchy manipulation
*
* RCS ID: $Id$
\*=========================================================================*/
#include "aux.h"
/*=========================================================================*\
* Internal function prototypes
\*=========================================================================*/
static void *aux_getgroupudata(lua_State *L, const char *group, int objidx);
/*=========================================================================*\
* Exported functions
\*=========================================================================*/
/*-------------------------------------------------------------------------*\
* Creates a new class. A class has methods given by the func array and the
* field 'class' tells the object class. The table 'group' list the class
* groups the object belongs to.
\*-------------------------------------------------------------------------*/
void aux_newclass(lua_State *L, const char *name, luaL_reg *func)
{
luaL_newmetatable(L, name);
lua_pushstring(L, "__index");
lua_newtable(L);
luaL_openlib(L, NULL, func, 0);
lua_pushstring(L, "class");
lua_pushstring(L, name);
lua_settable(L, -3);
lua_settable(L, -3);
lua_pushstring(L, "group");
lua_newtable(L);
lua_settable(L, -3);
lua_pop(L, 1);
}
/*-------------------------------------------------------------------------*\
* Add group to object list of groups.
\*-------------------------------------------------------------------------*/
void aux_add2group(lua_State *L, const char *name, const char *group)
{
luaL_getmetatable(L, name);
lua_pushstring(L, "group");
lua_gettable(L, -2);
lua_pushstring(L, group);
lua_pushnumber(L, 1);
lua_settable(L, -3);
lua_pop(L, 2);
}
/*-------------------------------------------------------------------------*\
* Get a userdata making sure the object belongs to a given class.
\*-------------------------------------------------------------------------*/
void *aux_checkclass(lua_State *L, const char *name, int objidx)
{
void *data = luaL_checkudata(L, objidx, name);
if (!data) {
char msg[45];
sprintf(msg, "%.35s expected", name);
luaL_argerror(L, objidx, msg);
}
return data;
}
/*-------------------------------------------------------------------------*\
* Get a userdata making sure the object belongs to a given group.
\*-------------------------------------------------------------------------*/
void *aux_checkgroup(lua_State *L, const char *group, int objidx)
{
void *data = aux_getgroupudata(L, group, objidx);
if (!data) {
char msg[45];
sprintf(msg, "%.35s expected", group);
luaL_argerror(L, objidx, msg);
}
return data;
}
/*-------------------------------------------------------------------------*\
* Set object class.
\*-------------------------------------------------------------------------*/
void aux_setclass(lua_State *L, const char *name, int objidx)
{
luaL_getmetatable(L, name);
if (objidx < 0) objidx--;
lua_setmetatable(L, objidx);
}
/*=========================================================================*\
* Internal functions
\*=========================================================================*/
/*-------------------------------------------------------------------------*\
* Get a userdata if object belongs to a given group.
\*-------------------------------------------------------------------------*/
static void *aux_getgroupudata(lua_State *L, const char *group, int objidx)
{
if (!lua_getmetatable(L, objidx)) return NULL;
lua_pushstring(L, "group");
lua_gettable(L, -2);
if (lua_isnil(L, -1)) {
lua_pop(L, 2);
return NULL;
}
lua_pushstring(L, group);
lua_gettable(L, -2);
if (lua_isnil(L, -1)) {
lua_pop(L, 3);
return NULL;
}
lua_pop(L, 3);
return lua_touserdata(L, objidx);
}

26
src/auxiliar.h Normal file
View File

@ -0,0 +1,26 @@
/*=========================================================================*\
* Auxiliar routines for class hierarchy manipulation
*
* RCS ID: $Id$
\*=========================================================================*/
#ifndef AUX_H
#define AUX_H
#include <lua.h>
#include <lauxlib.h>
void aux_newclass(lua_State *L, const char *name, luaL_reg *func);
void aux_add2group(lua_State *L, const char *name, const char *group);
void *aux_checkclass(lua_State *L, const char *name, int objidx);
void *aux_checkgroup(lua_State *L, const char *group, int objidx);
void aux_setclass(lua_State *L, const char *name, int objidx);
/* min and max macros */
#ifndef MIN
#define MIN(x, y) ((x) < (y) ? x : y)
#endif
#ifndef MAX
#define MAX(x, y) ((x) > (y) ? x : y)
#endif
#endif

View File

@ -1,28 +1,24 @@
/*=========================================================================*\ /*=========================================================================*\
* Buffered input/output routines * Buffered input/output routines
* Lua methods:
* send: unbuffered send using C base_send
* receive: buffered read using C base_receive
* *
* RCS ID: $Id$ * RCS ID: $Id$
\*=========================================================================*/ \*=========================================================================*/
#include <lua.h> #include <lua.h>
#include <lauxlib.h> #include <lauxlib.h>
#include "lsbuf.h" #include "error.h"
#include "aux.h"
#include "buf.h"
/*=========================================================================*\ /*=========================================================================*\
* Internal function prototypes. * Internal function prototypes
\*=========================================================================*/ \*=========================================================================*/
static int sendraw(lua_State *L, p_buf buf, cchar *data, size_t len,
size_t *done);
static int recvraw(lua_State *L, p_buf buf, size_t wanted); static int recvraw(lua_State *L, p_buf buf, size_t wanted);
static int recvdosline(lua_State *L, p_buf buf); static int recvline(lua_State *L, p_buf buf);
static int recvunixline(lua_State *L, p_buf buf);
static int recvall(lua_State *L, p_buf buf); static int recvall(lua_State *L, p_buf buf);
static int buf_get(p_buf buf, const char **data, size_t *count);
static int buf_contents(lua_State *L, p_buf buf, cchar **data, size_t *len); static void buf_skip(p_buf buf, size_t count);
static void buf_skip(lua_State *L, p_buf buf, size_t len); static int sendraw(p_buf buf, const char *data, size_t count, size_t *sent);
/*=========================================================================*\ /*=========================================================================*\
* Exported functions * Exported functions
@ -37,98 +33,69 @@ void buf_open(lua_State *L)
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Initializes C structure * Initializes C structure
* Input
* buf: buffer structure to initialize
* base: socket object to associate with buffer structure
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
void buf_init(lua_State *L, p_buf buf, p_base base) void buf_init(p_buf buf, p_io io, p_tm tm)
{ {
(void) L; buf->first = buf->last = 0;
buf->buf_first = buf->buf_last = 0; buf->io = io;
buf->buf_base = base; buf->tm = tm;
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Send data through buffered object * Send data through buffered object
* Input
* buf: buffer structure to be used
* Lua Input: self, a_1 [, a_2, a_3 ... a_n]
* self: socket object
* a_i: strings to be sent.
* Lua Returns
* On success: nil, followed by the total number of bytes sent
* On error: error message
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
int buf_send(lua_State *L, p_buf buf) int buf_meth_send(lua_State *L, p_buf buf)
{ {
int top = lua_gettop(L); int top = lua_gettop(L);
size_t total = 0; size_t total = 0;
int err = PRIV_DONE; int arg, err = IO_DONE;
int arg; p_tm tm = buf->tm;
p_base base = buf->buf_base; tm_markstart(tm);
tm_markstart(&base->base_tm);
for (arg = 2; arg <= top; arg++) { /* first arg is socket object */ for (arg = 2; arg <= top; arg++) { /* first arg is socket object */
size_t done, len; size_t sent, count;
cchar *data = luaL_optlstring(L, arg, NULL, &len); const char *data = luaL_optlstring(L, arg, NULL, &count);
if (!data || err != PRIV_DONE) break; if (!data || err != IO_DONE) break;
err = sendraw(L, buf, data, len, &done); err = sendraw(buf, data, count, &sent);
total += done; total += sent;
} }
priv_pusherror(L, err);
lua_pushnumber(L, total); lua_pushnumber(L, total);
error_push(L, err);
#ifdef LUASOCKET_DEBUG #ifdef LUASOCKET_DEBUG
/* push time elapsed during operation as the last return value */ /* push time elapsed during operation as the last return value */
lua_pushnumber(L, tm_getelapsed(&base->base_tm)/1000.0); lua_pushnumber(L, (tm_gettime() - tm_getstart(tm))/1000.0);
#endif #endif
return lua_gettop(L) - top; return lua_gettop(L) - top;
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Receive data from a buffered object * Receive data from a buffered object
* Input
* buf: buffer structure to be used
* Lua Input: self [pat_1, pat_2 ... pat_n]
* self: socket object
* pat_i: may be one of the following
* "*l": reads a text line, defined as a string of caracters terminates
* by a LF character, preceded or not by a CR character. This is
* the default pattern
* "*lu": reads a text line, terminanted by a CR character only. (Unix mode)
* "*a": reads until connection closed
* number: reads 'number' characters from the socket object
* Lua Returns
* On success: one string for each pattern
* On error: all strings for which there was no error, followed by one
* nil value for the remaining strings, followed by an error code
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
int buf_receive(lua_State *L, p_buf buf) int buf_meth_receive(lua_State *L, p_buf buf)
{ {
int top = lua_gettop(L); int top = lua_gettop(L);
int arg, err = PRIV_DONE; int arg, err = IO_DONE;
p_base base = buf->buf_base; p_tm tm = buf->tm;
tm_markstart(&base->base_tm); tm_markstart(tm);
/* push default pattern if need be */ /* push default pattern if need be */
if (top < 2) { if (top < 2) {
lua_pushstring(L, "*l"); lua_pushstring(L, "*l");
top++; top++;
} }
/* make sure we have enough stack space */ /* make sure we have enough stack space for all returns */
luaL_checkstack(L, top+LUA_MINSTACK, "too many arguments"); luaL_checkstack(L, top+LUA_MINSTACK, "too many arguments");
/* receive all patterns */ /* receive all patterns */
for (arg = 2; arg <= top && err == PRIV_DONE; arg++) { for (arg = 2; arg <= top && err == IO_DONE; arg++) {
if (!lua_isnumber(L, arg)) { if (!lua_isnumber(L, arg)) {
static cchar *patternnames[] = {"*l", "*lu", "*a", "*w", NULL}; static const char *patternnames[] = {"*l", "*a", NULL};
cchar *pattern = luaL_optstring(L, arg, NULL); const char *pattern = lua_isnil(L, arg) ?
"*l" : luaL_checkstring(L, arg);
/* get next pattern */ /* get next pattern */
switch (luaL_findstring(pattern, patternnames)) { switch (luaL_findstring(pattern, patternnames)) {
case 0: /* DOS line pattern */ case 0: /* line pattern */
err = recvdosline(L, buf); break; err = recvline(L, buf); break;
case 1: /* Unix line pattern */ case 1: /* until closed pattern */
err = recvunixline(L, buf); break; err = recvall(L, buf);
case 2: /* Until closed pattern */ if (err == IO_CLOSED) err = IO_DONE;
err = recvall(L, buf); break;
case 3: /* Word pattern */
luaL_argcheck(L, 0, arg, "word patterns are deprecated");
break; break;
default: /* else it is an error */ default: /* else it is an error */
luaL_argcheck(L, 0, arg, "invalid receive pattern"); luaL_argcheck(L, 0, arg, "invalid receive pattern");
@ -140,25 +107,20 @@ int buf_receive(lua_State *L, p_buf buf)
/* push nil for each pattern after an error */ /* push nil for each pattern after an error */
for ( ; arg <= top; arg++) lua_pushnil(L); for ( ; arg <= top; arg++) lua_pushnil(L);
/* last return is an error code */ /* last return is an error code */
priv_pusherror(L, err); error_push(L, err);
#ifdef LUASOCKET_DEBUG #ifdef LUASOCKET_DEBUG
/* push time elapsed during operation as the last return value */ /* push time elapsed during operation as the last return value */
lua_pushnumber(L, tm_getelapsed(&base->base_tm)/1000.0); lua_pushnumber(L, (tm_gettime() - tm_getstart(tm))/1000.0);
#endif #endif
return lua_gettop(L) - top; return lua_gettop(L) - top;
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Determines if there is any data in the read buffer * Determines if there is any data in the read buffer
* Input
* buf: buffer structure to be used
* Returns
* 1 if empty, 0 if there is data
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
int buf_isempty(lua_State *L, p_buf buf) int buf_isempty(p_buf buf)
{ {
(void) L; return buf->first >= buf->last;
return buf->buf_first >= buf->buf_last;
} }
/*=========================================================================*\ /*=========================================================================*\
@ -166,24 +128,16 @@ int buf_isempty(lua_State *L, p_buf buf)
\*=========================================================================*/ \*=========================================================================*/
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Sends a raw block of data through a buffered object. * Sends a raw block of data through a buffered object.
* Input
* buf: buffer structure to be used
* data: data to be sent
* len: number of bytes to send
* Output
* sent: number of bytes sent
* Returns
* operation error code.
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
static int sendraw(lua_State *L, p_buf buf, cchar *data, size_t len, static int sendraw(p_buf buf, const char *data, size_t count, size_t *sent)
size_t *sent)
{ {
p_base base = buf->buf_base; p_io io = buf->io;
p_tm tm = buf->tm;
size_t total = 0; size_t total = 0;
int err = PRIV_DONE; int err = IO_DONE;
while (total < len && err == PRIV_DONE) { while (total < count && err == IO_DONE) {
size_t done; size_t done;
err = base->base_send(L, base, data + total, len - total, &done); err = io->send(io->ctx, data+total, count-total, &done, tm_get(tm));
total += done; total += done;
} }
*sent = total; *sent = total;
@ -192,25 +146,21 @@ static int sendraw(lua_State *L, p_buf buf, cchar *data, size_t len,
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Reads a raw block of data from a buffered object. * Reads a raw block of data from a buffered object.
* Input
* buf: buffer structure
* wanted: number of bytes to be read
* Returns
* operation error code.
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
static int recvraw(lua_State *L, p_buf buf, size_t wanted) static
int recvraw(lua_State *L, p_buf buf, size_t wanted)
{ {
int err = PRIV_DONE; int err = IO_DONE;
size_t total = 0; size_t total = 0;
luaL_Buffer b; luaL_Buffer b;
luaL_buffinit(L, &b); luaL_buffinit(L, &b);
while (total < wanted && err == PRIV_DONE) { while (total < wanted && err == IO_DONE) {
size_t len; cchar *data; size_t count; const char *data;
err = buf_contents(L, buf, &data, &len); err = buf_get(buf, &data, &count);
len = MIN(len, wanted - total); count = MIN(count, wanted - total);
luaL_addlstring(&b, data, len); luaL_addlstring(&b, data, count);
buf_skip(L, buf, len); buf_skip(buf, count);
total += len; total += count;
} }
luaL_pushresult(&b); luaL_pushresult(&b);
return err; return err;
@ -218,21 +168,18 @@ static int recvraw(lua_State *L, p_buf buf, size_t wanted)
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Reads everything until the connection is closed * Reads everything until the connection is closed
* Input
* buf: buffer structure
* Result
* operation error code.
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
static int recvall(lua_State *L, p_buf buf) static
int recvall(lua_State *L, p_buf buf)
{ {
int err = PRIV_DONE; int err = IO_DONE;
luaL_Buffer b; luaL_Buffer b;
luaL_buffinit(L, &b); luaL_buffinit(L, &b);
while (err == PRIV_DONE) { while (err == IO_DONE) {
cchar *data; size_t len; const char *data; size_t count;
err = buf_contents(L, buf, &data, &len); err = buf_get(buf, &data, &count);
luaL_addlstring(&b, data, len); luaL_addlstring(&b, data, count);
buf_skip(L, buf, len); buf_skip(buf, count);
} }
luaL_pushresult(&b); luaL_pushresult(&b);
return err; return err;
@ -241,61 +188,27 @@ static int recvall(lua_State *L, p_buf buf)
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Reads a line terminated by a CR LF pair or just by a LF. The CR and LF * Reads a line terminated by a CR LF pair or just by a LF. The CR and LF
* are not returned by the function and are discarded from the buffer. * are not returned by the function and are discarded from the buffer.
* Input
* buf: buffer structure
* Result
* operation error code. PRIV_DONE, PRIV_TIMEOUT or PRIV_CLOSED
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
static int recvdosline(lua_State *L, p_buf buf) static
int recvline(lua_State *L, p_buf buf)
{ {
int err = 0; int err = 0;
luaL_Buffer b; luaL_Buffer b;
luaL_buffinit(L, &b); luaL_buffinit(L, &b);
while (err == PRIV_DONE) { while (err == IO_DONE) {
size_t len, pos; cchar *data; size_t count, pos; const char *data;
err = buf_contents(L, buf, &data, &len); err = buf_get(buf, &data, &count);
pos = 0; pos = 0;
while (pos < len && data[pos] != '\n') { while (pos < count && data[pos] != '\n') {
/* we ignore all \r's */ /* we ignore all \r's */
if (data[pos] != '\r') luaL_putchar(&b, data[pos]); if (data[pos] != '\r') luaL_putchar(&b, data[pos]);
pos++; pos++;
} }
if (pos < len) { /* found '\n' */ if (pos < count) { /* found '\n' */
buf_skip(L, buf, pos+1); /* skip '\n' too */ buf_skip(buf, pos+1); /* skip '\n' too */
break; /* we are done */ break; /* we are done */
} else /* reached the end of the buffer */ } else /* reached the end of the buffer */
buf_skip(L, buf, pos); buf_skip(buf, pos);
}
luaL_pushresult(&b);
return err;
}
/*-------------------------------------------------------------------------*\
* Reads a line terminated by a LF character, which is not returned by
* the function, and is skipped in the buffer.
* Input
* buf: buffer structure
* Returns
* operation error code. PRIV_DONE, PRIV_TIMEOUT or PRIV_CLOSED
\*-------------------------------------------------------------------------*/
static int recvunixline(lua_State *L, p_buf buf)
{
int err = PRIV_DONE;
luaL_Buffer b;
luaL_buffinit(L, &b);
while (err == 0) {
size_t pos, len; cchar *data;
err = buf_contents(L, buf, &data, &len);
pos = 0;
while (pos < len && data[pos] != '\n') {
luaL_putchar(&b, data[pos]);
pos++;
}
if (pos < len) { /* found '\n' */
buf_skip(L, buf, pos+1); /* skip '\n' too */
break; /* we are done */
} else /* reached the end of the buffer */
buf_skip(L, buf, pos);
} }
luaL_pushresult(&b); luaL_pushresult(&b);
return err; return err;
@ -303,38 +216,32 @@ static int recvunixline(lua_State *L, p_buf buf)
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Skips a given number of bytes in read buffer * Skips a given number of bytes in read buffer
* Input
* buf: buffer structure
* len: number of bytes to skip
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
static void buf_skip(lua_State *L, p_buf buf, size_t len) static
void buf_skip(p_buf buf, size_t count)
{ {
buf->buf_first += len; buf->first += count;
if (buf_isempty(L, buf)) buf->buf_first = buf->buf_last = 0; if (buf_isempty(buf))
buf->first = buf->last = 0;
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Return any data available in buffer, or get more data from transport layer * Return any data available in buffer, or get more data from transport layer
* if buffer is empty. * if buffer is empty.
* Input
* buf: buffer structure
* Output
* data: pointer to buffer start
* len: buffer buffer length
* Returns
* PRIV_DONE, PRIV_CLOSED, PRIV_TIMEOUT ...
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
static int buf_contents(lua_State *L, p_buf buf, cchar **data, size_t *len) static
int buf_get(p_buf buf, const char **data, size_t *count)
{ {
int err = PRIV_DONE; int err = IO_DONE;
p_base base = buf->buf_base; p_io io = buf->io;
if (buf_isempty(L, buf)) { p_tm tm = buf->tm;
size_t done; if (buf_isempty(buf)) {
err = base->base_receive(L, base, buf->buf_data, BUF_SIZE, &done); size_t got;
buf->buf_first = 0; err = io->recv(io->ctx, buf->data, BUF_SIZE, &got, tm_get(tm));
buf->buf_last = done; buf->first = 0;
buf->last = got;
} }
*len = buf->buf_last - buf->buf_first; *count = buf->last - buf->first;
*data = buf->buf_data + buf->buf_first; *data = buf->data + buf->first;
return err; return err;
} }

View File

@ -3,11 +3,12 @@
* *
* RCS ID: $Id$ * RCS ID: $Id$
\*=========================================================================*/ \*=========================================================================*/
#ifndef BUF_H_ #ifndef BUF_H
#define BUF_H_ #define BUF_H
#include <lua.h> #include <lua.h>
#include "lsbase.h" #include "io.h"
#include "tm.h"
/* buffer size in bytes */ /* buffer size in bytes */
#define BUF_SIZE 8192 #define BUF_SIZE 8192
@ -15,10 +16,11 @@
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Buffer control structure * Buffer control structure
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
typedef struct t_buf_tag { typedef struct t_buf_ {
size_t buf_first, buf_last; p_io io; /* IO driver used for this buffer */
char buf_data[BUF_SIZE]; p_tm tm; /* timeout management for this buffer */
p_base buf_base; size_t first, last; /* index of first and last bytes of stored data */
char data[BUF_SIZE]; /* storage space for buffer data */
} t_buf; } t_buf;
typedef t_buf *p_buf; typedef t_buf *p_buf;
@ -26,9 +28,9 @@ typedef t_buf *p_buf;
* Exported functions * Exported functions
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
void buf_open(lua_State *L); void buf_open(lua_State *L);
void buf_init(lua_State *L, p_buf buf, p_base base); void buf_init(p_buf buf, p_io io, p_tm tm);
int buf_send(lua_State *L, p_buf buf); int buf_meth_send(lua_State *L, p_buf buf);
int buf_receive(lua_State *L, p_buf buf); int buf_meth_receive(lua_State *L, p_buf buf);
int buf_isempty(lua_State *L, p_buf buf); int buf_isempty(p_buf buf);
#endif /* BUF_H_ */ #endif /* BUF_H */

View File

@ -7,7 +7,8 @@
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
local Public, Private = {}, {} local Public, Private = {}, {}
socket.ftp = Public local socket = _G[LUASOCKET_LIBNAME] -- get LuaSocket namespace
socket.ftp = Public -- create ftp sub namespace
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
-- Program constants -- Program constants
@ -22,6 +23,33 @@ Public.EMAIL = "anonymous@anonymous.org"
-- block size used in transfers -- block size used in transfers
Public.BLOCKSIZE = 8192 Public.BLOCKSIZE = 8192
-----------------------------------------------------------------------------
-- Tries to get a pattern from the server and closes socket on error
-- sock: socket connected to the server
-- pattern: pattern to receive
-- Returns
-- received pattern on success
-- nil followed by error message on error
-----------------------------------------------------------------------------
function Private.try_receive(sock, pattern)
local data, err = sock:receive(pattern)
if not data then sock:close() end
return data, err
end
-----------------------------------------------------------------------------
-- Tries to send data to the server and closes socket on error
-- sock: socket connected to the server
-- data: data to send
-- Returns
-- err: error message if any, nil if successfull
-----------------------------------------------------------------------------
function Private.try_send(sock, data)
local sent, err = sock:send(data)
if not sent then sock:close() end
return err
end
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
-- Tries to send DOS mode lines. Closes socket on error. -- Tries to send DOS mode lines. Closes socket on error.
-- Input -- Input
@ -31,24 +59,7 @@ Public.BLOCKSIZE = 8192
-- err: message in case of error, nil if successfull -- err: message in case of error, nil if successfull
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
function Private.try_sendline(sock, line) function Private.try_sendline(sock, line)
local err = sock:send(line .. "\r\n") return Private.try_send(sock, line .. "\r\n")
if err then sock:close() end
return err
end
-----------------------------------------------------------------------------
-- Tries to get a pattern from the server and closes socket on error
-- sock: socket connected to the server
-- ...: pattern to receive
-- Returns
-- ...: received pattern
-- err: error message if any
-----------------------------------------------------------------------------
function Private.try_receive(...)
local sock = arg[1]
local data, err = sock.receive(unpack(arg))
if err then sock:close() end
return data, err
end end
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
@ -307,20 +318,20 @@ end
-- nil if successfull, or an error message in case of error -- nil if successfull, or an error message in case of error
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
function Private.send_indirect(data, send_cb, chunk, size) function Private.send_indirect(data, send_cb, chunk, size)
local sent, err local total, sent, err
sent = 0 total = 0
while 1 do while 1 do
if type(chunk) ~= "string" or type(size) ~= "number" then if type(chunk) ~= "string" or type(size) ~= "number" then
data:close() data:close()
if not chunk and type(size) == "string" then return size if not chunk and type(size) == "string" then return size
else return "invalid callback return" end else return "invalid callback return" end
end end
err = data:send(chunk) sent, err = data:send(chunk)
if err then if err then
data:close() data:close()
return err return err
end end
sent = sent + string.len(chunk) total = total + sent
if sent >= size then break end if sent >= size then break end
chunk, size = send_cb() chunk, size = send_cb()
end end

View File

@ -7,7 +7,8 @@
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
local Public, Private = {}, {} local Public, Private = {}, {}
socket.http = Public local socket = _G[LUASOCKET_LIBNAME] -- get LuaSocket namespace
socket.http = Public -- create http sub namespace
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
-- Program constants -- Program constants
@ -24,19 +25,15 @@ Public.BLOCKSIZE = 8192
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
-- Tries to get a pattern from the server and closes socket on error -- Tries to get a pattern from the server and closes socket on error
-- sock: socket connected to the server -- sock: socket connected to the server
-- ...: pattern to receive -- pattern: pattern to receive
-- Returns -- Returns
-- ...: received pattern -- received pattern on success
-- err: error message if any -- nil followed by error message on error
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
function Private.try_receive(...) function Private.try_receive(sock, pattern)
local sock = arg[1] local data, err = sock:receive(pattern)
local data, err = sock.receive(unpack(arg)) if not data then sock:close() end
if err then return data, err
sock:close()
return nil, err
end
return data
end end
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
@ -47,8 +44,8 @@ end
-- err: error message if any, nil if successfull -- err: error message if any, nil if successfull
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
function Private.try_send(sock, data) function Private.try_send(sock, data)
local err = sock:send(data) local sent, err = sock:send(data)
if err then sock:close() end if not sent then sock:close() end
return err return err
end end
@ -285,21 +282,21 @@ end
-- nil if successfull, or an error message in case of error -- nil if successfull, or an error message in case of error
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
function Private.send_indirect(data, send_cb, chunk, size) function Private.send_indirect(data, send_cb, chunk, size)
local sent, err local total, sent, err
sent = 0 total = 0
while 1 do while 1 do
if type(chunk) ~= "string" or type(size) ~= "number" then if type(chunk) ~= "string" or type(size) ~= "number" then
data:close() data:close()
if not chunk and type(size) == "string" then return size if not chunk and type(size) == "string" then return size
else return "invalid callback return" end else return "invalid callback return" end
end end
err = data:send(chunk) sent, err = data:send(chunk)
if err then if err then
data:close() data:close()
return err return err
end end
sent = sent + string.len(chunk) total = total + sent
if sent >= size then break end if total >= size then break end
chunk, size = send_cb() chunk, size = send_cb()
end end
end end

View File

@ -1,12 +1,5 @@
/*=========================================================================*\ /*=========================================================================*\
* Internet domain class: inherits from the Socket class, and implement * Internet domain functions
* a few methods shared by all internet related objects
* Lua methods:
* getpeername: gets socket peer ip address and port
* getsockname: gets local socket ip address and port
* Global Lua fuctions:
* toip: gets resolver info on host name
* tohostname: gets resolver info on dotted-quad
* *
* RCS ID: $Id$ * RCS ID: $Id$
\*=========================================================================*/ \*=========================================================================*/
@ -15,23 +8,27 @@
#include <lua.h> #include <lua.h>
#include <lauxlib.h> #include <lauxlib.h>
#include "lsinet.h" #include "luasocket.h"
#include "lssock.h" #include "inet.h"
#include "lscompat.h"
/*=========================================================================*\ /*=========================================================================*\
* Internal function prototypes. * Internal function prototypes.
\*=========================================================================*/ \*=========================================================================*/
static int inet_lua_toip(lua_State *L); static int inet_global_toip(lua_State *L);
static int inet_lua_tohostname(lua_State *L); static int inet_global_tohostname(lua_State *L);
static int inet_lua_getpeername(lua_State *L);
static int inet_lua_getsockname(lua_State *L);
static void inet_pushresolved(lua_State *L, struct hostent *hp); static void inet_pushresolved(lua_State *L, struct hostent *hp);
#ifdef COMPAT_INETATON #ifdef INET_ATON
static int inet_aton(cchar *cp, struct in_addr *inp); static int inet_aton(const char *cp, struct in_addr *inp);
#endif #endif
static luaL_reg func[] = {
{ "toip", inet_global_toip },
{ "tohostname", inet_global_tohostname },
{ NULL, NULL}
};
/*=========================================================================*\ /*=========================================================================*\
* Exported functions * Exported functions
\*=========================================================================*/ \*=========================================================================*/
@ -40,39 +37,7 @@ static int inet_aton(cchar *cp, struct in_addr *inp);
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
void inet_open(lua_State *L) void inet_open(lua_State *L)
{ {
lua_pushcfunction(L, inet_lua_toip); luaL_openlib(L, LUASOCKET_LIBNAME, func, 0);
priv_newglobal(L, "toip");
lua_pushcfunction(L, inet_lua_tohostname);
priv_newglobal(L, "tohostname");
priv_newglobalmethod(L, "getsockname");
priv_newglobalmethod(L, "getpeername");
}
/*-------------------------------------------------------------------------*\
* Hook lua methods to methods table.
* Input
* lsclass: class name
\*-------------------------------------------------------------------------*/
void inet_inherit(lua_State *L, cchar *lsclass)
{
unsigned int i;
static struct luaL_reg funcs[] = {
{"getsockname", inet_lua_getsockname},
{"getpeername", inet_lua_getpeername},
};
sock_inherit(L, lsclass);
for (i = 0; i < sizeof(funcs)/sizeof(funcs[0]); i++) {
lua_pushcfunction(L, funcs[i].func);
priv_setmethod(L, lsclass, funcs[i].name);
}
}
/*-------------------------------------------------------------------------*\
* Constructs the object
\*-------------------------------------------------------------------------*/
void inet_construct(lua_State *L, p_inet inet)
{
sock_construct(L, (p_sock) inet);
} }
/*=========================================================================*\ /*=========================================================================*\
@ -87,17 +52,18 @@ void inet_construct(lua_State *L, p_inet inet)
* On success: first IP address followed by a resolved table * On success: first IP address followed by a resolved table
* On error: nil, followed by an error message * On error: nil, followed by an error message
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
static int inet_lua_toip(lua_State *L) static int inet_global_toip(lua_State *L)
{ {
cchar *address = luaL_checkstring(L, 1); const char *address = luaL_checkstring(L, 1);
struct in_addr addr; struct in_addr addr;
struct hostent *hp; struct hostent *hp;
if (inet_aton(address, &addr)) if (inet_aton(address, &addr))
hp = gethostbyaddr((char *) &addr, sizeof(addr), AF_INET); hp = gethostbyaddr((char *) &addr, sizeof(addr), AF_INET);
else hp = gethostbyname(address); else
hp = gethostbyname(address);
if (!hp) { if (!hp) {
lua_pushnil(L); lua_pushnil(L);
lua_pushstring(L, compat_hoststrerror()); lua_pushstring(L, sock_hoststrerror());
return 2; return 2;
} }
addr = *((struct in_addr *) hp->h_addr); addr = *((struct in_addr *) hp->h_addr);
@ -115,17 +81,18 @@ static int inet_lua_toip(lua_State *L)
* On success: canonic name followed by a resolved table * On success: canonic name followed by a resolved table
* On error: nil, followed by an error message * On error: nil, followed by an error message
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
static int inet_lua_tohostname(lua_State *L) static int inet_global_tohostname(lua_State *L)
{ {
cchar *address = luaL_checkstring(L, 1); const char *address = luaL_checkstring(L, 1);
struct in_addr addr; struct in_addr addr;
struct hostent *hp; struct hostent *hp;
if (inet_aton(address, &addr)) if (inet_aton(address, &addr))
hp = gethostbyaddr((char *) &addr, sizeof(addr), AF_INET); hp = gethostbyaddr((char *) &addr, sizeof(addr), AF_INET);
else hp = gethostbyname(address); else
hp = gethostbyname(address);
if (!hp) { if (!hp) {
lua_pushnil(L); lua_pushnil(L);
lua_pushstring(L, compat_hoststrerror()); lua_pushstring(L, sock_hoststrerror());
return 2; return 2;
} }
lua_pushstring(L, hp->h_name); lua_pushstring(L, hp->h_name);
@ -138,18 +105,17 @@ static int inet_lua_tohostname(lua_State *L)
\*=========================================================================*/ \*=========================================================================*/
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Retrieves socket peer name * Retrieves socket peer name
* Lua Input: sock * Input:
* sock: socket * sock: socket
* Lua Returns * Lua Returns
* On success: ip address and port of peer * On success: ip address and port of peer
* On error: nil * On error: nil
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
static int inet_lua_getpeername(lua_State *L) int inet_meth_getpeername(lua_State *L, p_sock ps)
{ {
p_sock sock = (p_sock) lua_touserdata(L, 1);
struct sockaddr_in peer; struct sockaddr_in peer;
size_t peer_len = sizeof(peer); size_t peer_len = sizeof(peer);
if (getpeername(sock->fd, (SA *) &peer, &peer_len) < 0) { if (getpeername(*ps, (SA *) &peer, &peer_len) < 0) {
lua_pushnil(L); lua_pushnil(L);
return 1; return 1;
} }
@ -160,18 +126,17 @@ static int inet_lua_getpeername(lua_State *L)
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Retrieves socket local name * Retrieves socket local name
* Lua Input: sock * Input:
* sock: socket * sock: socket
* Lua Returns * Lua Returns
* On success: local ip address and port * On success: local ip address and port
* On error: nil * On error: nil
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
static int inet_lua_getsockname(lua_State *L) int inet_meth_getsockname(lua_State *L, p_sock ps)
{ {
p_sock sock = (p_sock) lua_touserdata(L, 1);
struct sockaddr_in local; struct sockaddr_in local;
size_t local_len = sizeof(local); size_t local_len = sizeof(local);
if (getsockname(sock->fd, (SA *) &local, &local_len) < 0) { if (getsockname(*ps, (SA *) &local, &local_len) < 0) {
lua_pushnil(L); lua_pushnil(L);
return 1; return 1;
} }
@ -222,47 +187,53 @@ static void inet_pushresolved(lua_State *L, struct hostent *hp)
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Tries to create a TCP socket and connect to remote address (address, port) * Tries to connect to remote address (address, port)
* Input * Input
* client: socket structure to be used * ps: pointer to socket
* address: host name or ip address * address: host name or ip address
* port: port number to bind to * port: port number to bind to
* Returns * Returns
* NULL in case of success, error message otherwise * NULL in case of success, error message otherwise
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
cchar *inet_tryconnect(p_inet inet, cchar *address, ushort port) const char *inet_tryconnect(p_sock ps, const char *address, ushort port)
{ {
struct sockaddr_in remote; struct sockaddr_in remote;
memset(&remote, 0, sizeof(remote)); memset(&remote, 0, sizeof(remote));
remote.sin_family = AF_INET; remote.sin_family = AF_INET;
remote.sin_port = htons(port); remote.sin_port = htons(port);
if (!strlen(address) || !inet_aton(address, &remote.sin_addr)) { if (strcmp(address, "*")) {
struct hostent *hp = gethostbyname(address); if (!strlen(address) || !inet_aton(address, &remote.sin_addr)) {
struct in_addr **addr; struct hostent *hp = gethostbyname(address);
if (!hp) return compat_hoststrerror(); struct in_addr **addr;
addr = (struct in_addr **) hp->h_addr_list; remote.sin_family = AF_INET;
memcpy(&remote.sin_addr, *addr, sizeof(struct in_addr)); if (!hp) return sock_hoststrerror();
} addr = (struct in_addr **) hp->h_addr_list;
compat_setblocking(inet->fd); memcpy(&remote.sin_addr, *addr, sizeof(struct in_addr));
if (compat_connect(inet->fd, (SA *) &remote, sizeof(remote)) < 0) { }
const char *err = compat_connectstrerror(); } else remote.sin_family = AF_UNSPEC;
compat_close(inet->fd); sock_setblocking(ps);
inet->fd = COMPAT_INVALIDFD; const char *err = sock_connect(ps, (SA *) &remote, sizeof(remote));
if (err) {
sock_destroy(ps);
*ps = SOCK_INVALID;
return err; return err;
} else {
sock_setnonblocking(ps);
return NULL;
} }
compat_setnonblocking(inet->fd);
return NULL;
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Tries to create a TCP socket and bind it to (address, port) * Tries to bind socket to (address, port)
* Input * Input
* sock: pointer to socket
* address: host name or ip address * address: host name or ip address
* port: port number to bind to * port: port number to bind to
* Returns * Returns
* NULL in case of success, error message otherwise * NULL in case of success, error message otherwise
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
cchar *inet_trybind(p_inet inet, cchar *address, ushort port) const char *inet_trybind(p_sock ps, const char *address, ushort port,
int backlog)
{ {
struct sockaddr_in local; struct sockaddr_in local;
memset(&local, 0, sizeof(local)); memset(&local, 0, sizeof(local));
@ -274,34 +245,33 @@ cchar *inet_trybind(p_inet inet, cchar *address, ushort port)
(!strlen(address) || !inet_aton(address, &local.sin_addr))) { (!strlen(address) || !inet_aton(address, &local.sin_addr))) {
struct hostent *hp = gethostbyname(address); struct hostent *hp = gethostbyname(address);
struct in_addr **addr; struct in_addr **addr;
if (!hp) return compat_hoststrerror(); if (!hp) return sock_hoststrerror();
addr = (struct in_addr **) hp->h_addr_list; addr = (struct in_addr **) hp->h_addr_list;
memcpy(&local.sin_addr, *addr, sizeof(struct in_addr)); memcpy(&local.sin_addr, *addr, sizeof(struct in_addr));
} }
compat_setblocking(inet->fd); sock_setblocking(ps);
if (compat_bind(inet->fd, (SA *) &local, sizeof(local)) < 0) { const char *err = sock_bind(ps, (SA *) &local, sizeof(local));
const char *err = compat_bindstrerror(); if (err) {
compat_close(inet->fd); sock_destroy(ps);
inet->fd = COMPAT_INVALIDFD; *ps = SOCK_INVALID;
return err; return err;
} else {
sock_setnonblocking(ps);
if (backlog > 0) sock_listen(ps, backlog);
return NULL;
} }
compat_setnonblocking(inet->fd);
return NULL;
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Tries to create a new inet socket * Tries to create a new inet socket
* Input * Input
* udp: udp structure * sock: pointer to socket
* Returns * Returns
* NULL if successfull, error message on error * NULL if successfull, error message on error
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
cchar *inet_trysocket(p_inet inet, int type) const char *inet_trycreate(p_sock ps, int type)
{ {
if (inet->fd != COMPAT_INVALIDFD) compat_close(inet->fd); return sock_create(ps, AF_INET, type, 0);
inet->fd = compat_socket(AF_INET, type, 0);
if (inet->fd == COMPAT_INVALIDFD) return compat_socketstrerror();
else return NULL;
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\

View File

@ -1,38 +1,26 @@
/*=========================================================================*\ /*=========================================================================*\
* Internet domain class: inherits from the Socket class, and implement * Internet domain functions
* a few methods shared by all internet related objects
* *
* RCS ID: $Id$ * RCS ID: $Id$
\*=========================================================================*/ \*=========================================================================*/
#ifndef INET_H_ #ifndef INET_H
#define INET_H_ #define INET_H
#include <lua.h> #include <lua.h>
#include "lssock.h" #include "sock.h"
/* class name */
#define INET_CLASS "luasocket(inet)"
/*-------------------------------------------------------------------------*\
* Socket fields
\*-------------------------------------------------------------------------*/
#define INET_FIELDS SOCK_FIELDS
/*-------------------------------------------------------------------------*\
* Socket structure
\*-------------------------------------------------------------------------*/
typedef t_sock t_inet;
typedef t_inet *p_inet;
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Exported functions * Exported functions
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
void inet_open(lua_State *L); void inet_open(lua_State *L);
void inet_construct(lua_State *L, p_inet inet);
void inet_inherit(lua_State *L, cchar *lsclass);
cchar *inet_tryconnect(p_sock sock, cchar *address, ushort); const char *inet_tryconnect(p_sock ps, const char *address,
cchar *inet_trybind(p_sock sock, cchar *address, ushort); unsigned short port);
cchar *inet_trysocket(p_inet inet, int type); const char *inet_trybind(p_sock ps, const char *address,
unsigned short port, int backlog);
const char *inet_trycreate(p_sock ps, int type);
int inet_meth_getpeername(lua_State *L, p_sock ps);
int inet_meth_getsockname(lua_State *L, p_sock ps);
#endif /* INET_H_ */ #endif /* INET_H_ */

8
src/io.c Normal file
View File

@ -0,0 +1,8 @@
#include "io.h"
void io_init(p_io io, p_send send, p_recv recv, void *ctx)
{
io->send = send;
io->recv = recv;
io->ctx = ctx;
}

34
src/io.h Normal file
View File

@ -0,0 +1,34 @@
#ifndef IO_H
#define IO_H
#include "error.h"
/* interface to send function */
typedef int (*p_send) (
void *ctx, /* context needed by send */
const char *data, /* pointer to buffer with data to send */
size_t count, /* number of bytes to send from buffer */
size_t *sent, /* number of bytes sent uppon return */
int timeout /* number of miliseconds left for transmission */
);
/* interface to recv function */
typedef int (*p_recv) (
void *ctx, /* context needed by recv */
char *data, /* pointer to buffer where data will be writen */
size_t count, /* number of bytes to receive into buffer */
size_t *got, /* number of bytes received uppon return */
int timeout /* number of miliseconds left for transmission */
);
/* IO driver definition */
typedef struct t_io_ {
void *ctx; /* context needed by send/recv */
p_send send; /* send function pointer */
p_recv recv; /* receive function pointer */
} t_io;
typedef t_io *p_io;
void io_init(p_io io, p_send send, p_recv recv, void *ctx);
#endif /* IO_H */

View File

@ -23,18 +23,13 @@
* LuaSocket includes * LuaSocket includes
\*=========================================================================*/ \*=========================================================================*/
#include "luasocket.h" #include "luasocket.h"
#include "lspriv.h"
#include "lsselect.h" #include "tm.h"
#include "lscompat.h" #include "buf.h"
#include "lsbase.h" #include "sock.h"
#include "lstm.h" #include "inet.h"
#include "lsbuf.h" #include "tcp.h"
#include "lssock.h" #include "udp.h"
#include "lsinet.h"
#include "lstcpc.h"
#include "lstcps.h"
#include "lstcps.h"
#include "lsudp.h"
/*=========================================================================*\ /*=========================================================================*\
* Exported functions * Exported functions
@ -42,34 +37,29 @@
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Initializes all library modules. * Initializes all library modules.
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
LUASOCKET_API int lua_socketlibopen(lua_State *L) LUASOCKET_API int luaopen_socketlib(lua_State *L)
{ {
compat_open(L); /* create namespace table */
priv_open(L); lua_pushstring(L, LUASOCKET_LIBNAME);
select_open(L); lua_newtable(L);
base_open(L); #ifdef LUASOCKET_DEBUG
tm_open(L); lua_pushstring(L, "debug");
fd_open(L); lua_pushnumber(L, 1);
sock_open(L); lua_settable(L, -3);
inet_open(L);
tcpc_open(L);
buf_open(L);
tcps_open(L);
udp_open(L);
#ifdef LUASOCKET_DOFILE
lua_dofile(L, "concat.lua");
lua_dofile(L, "code.lua");
lua_dofile(L, "url.lua");
lua_dofile(L, "http.lua");
lua_dofile(L, "smtp.lua");
lua_dofile(L, "ftp.lua");
#else
#include "concat.loh"
#include "code.loh"
#include "url.loh"
#include "http.loh"
#include "smtp.loh"
#include "ftp.loh"
#endif #endif
lua_settable(L, LUA_GLOBALSINDEX);
/* make sure modules know what is our namespace */
lua_pushstring(L, "LUASOCKET_LIBNAME");
lua_pushstring(L, LUASOCKET_LIBNAME);
lua_settable(L, LUA_GLOBALSINDEX);
/* initialize all modules */
sock_open(L);
tm_open(L);
buf_open(L);
inet_open(L);
tcp_open(L);
udp_open(L);
/* load all Lua code */
lua_dofile(L, "luasocket.lua");
return 0; return 0;
} }

View File

@ -5,8 +5,8 @@
* *
* RCS ID: $Id$ * RCS ID: $Id$
\*=========================================================================*/ \*=========================================================================*/
#ifndef _LUASOCKET_H_ #ifndef LUASOCKET_H
#define _LUASOCKET_H_ #define LUASOCKET_H
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Current luasocket version * Current luasocket version
@ -28,6 +28,6 @@
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Initializes the library. * Initializes the library.
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
LUASOCKET_API int lua_socketlibopen(lua_State *L); LUASOCKET_API int luaopen_socketlib(lua_State *L);
#endif /* _LUASOCKET_H_ */ #endif /* LUASOCKET_H */

View File

@ -5,10 +5,10 @@ mbox = Public
function Public.split_message(message_s) function Public.split_message(message_s)
local message = {} local message = {}
message_s = string.gsub(message_s, "\r\n", "\n") message_s = string.gsub(message_s, "\r\n", "\n")
string.gsub(message_s, "^(.-\n)\n", function (h) %message.headers = h end) string.gsub(message_s, "^(.-\n)\n", function (h) message.headers = h end)
string.gsub(message_s, "^.-\n\n(.*)", function (b) %message.body = b end) string.gsub(message_s, "^.-\n\n(.*)", function (b) message.body = b end)
if not message.body then if not message.body then
string.gsub(message_s, "^\n(.*)", function (b) %message.body = b end) string.gsub(message_s, "^\n(.*)", function (b) message.body = b end)
end end
if not message.headers and not message.body then if not message.headers and not message.body then
message.headers = message_s message.headers = message_s
@ -20,7 +20,7 @@ function Public.split_headers(headers_s)
local headers = {} local headers = {}
headers_s = string.gsub(headers_s, "\r\n", "\n") headers_s = string.gsub(headers_s, "\r\n", "\n")
headers_s = string.gsub(headers_s, "\n[ ]+", " ") headers_s = string.gsub(headers_s, "\n[ ]+", " ")
string.gsub("\n" .. headers_s, "\n([^\n]+)", function (h) table.insert(%headers, h) end) string.gsub("\n" .. headers_s, "\n([^\n]+)", function (h) table.insert(headers, h) end)
return headers return headers
end end
@ -32,10 +32,10 @@ function Public.parse_header(header_s)
end end
function Public.parse_headers(headers_s) function Public.parse_headers(headers_s)
local headers_t = %Public.split_headers(headers_s) local headers_t = Public.split_headers(headers_s)
local headers = {} local headers = {}
for i = 1, table.getn(headers_t) do for i = 1, table.getn(headers_t) do
local name, value = %Public.parse_header(headers_t[i]) local name, value = Public.parse_header(headers_t[i])
if name then if name then
name = string.lower(name) name = string.lower(name)
if headers[name] then if headers[name] then
@ -73,16 +73,16 @@ function Public.split_mbox(mbox_s)
end end
function Public.parse(mbox_s) function Public.parse(mbox_s)
local mbox = %Public.split_mbox(mbox_s) local mbox = Public.split_mbox(mbox_s)
for i = 1, table.getn(mbox) do for i = 1, table.getn(mbox) do
mbox[i] = %Public.parse_message(mbox[i]) mbox[i] = Public.parse_message(mbox[i])
end end
return mbox return mbox
end end
function Public.parse_message(message_s) function Public.parse_message(message_s)
local message = {} local message = {}
message.headers, message.body = %Public.split_message(message_s) message.headers, message.body = Public.split_message(message_s)
message.headers = %Public.parse_headers(message.headers) message.headers = Public.parse_headers(message.headers)
return message return message
end end

View File

@ -7,7 +7,8 @@
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
local Public, Private = {}, {} local Public, Private = {}, {}
socket.smtp = Public local socket = _G[LUASOCKET_LIBNAME] -- get LuaSocket namespace
socket.smtp = Public -- create smtp sub namespace
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
-- Program constants -- Program constants
@ -23,32 +24,30 @@ Public.DOMAIN = os.getenv("SERVER_NAME") or "localhost"
Public.SERVER = "localhost" Public.SERVER = "localhost"
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
-- Tries to send data through socket. Closes socket on error. -- Tries to get a pattern from the server and closes socket on error
-- Input -- sock: socket connected to the server
-- sock: server socket -- pattern: pattern to receive
-- data: string to be sent
-- Returns -- Returns
-- err: message in case of error, nil if successfull -- received pattern on success
-- nil followed by error message on error
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
function Private.try_send(sock, data) function Private.try_receive(sock, pattern)
local err = sock:send(data) local data, err = sock:receive(pattern)
if err then sock:close() end if not data then sock:close() end
return err return data, err
end end
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
-- Tries to get a pattern from the server and closes socket on error -- Tries to send data to the server and closes socket on error
-- sock: socket opened to the server -- sock: socket connected to the server
-- ...: pattern to receive -- data: data to send
-- Returns -- Returns
-- ...: received pattern -- err: error message if any, nil if successfull
-- err: error message if any
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
function Private.try_receive(...) function Private.try_send(sock, data)
local sock = arg[1] local sent, err = sock:send(data)
local data, err = sock.receive(unpack(arg)) if not sent then sock:close() end
if err then sock:close() end return err
return data, err
end end
----------------------------------------------------------------------------- -----------------------------------------------------------------------------

222
src/tcp.c Normal file
View File

@ -0,0 +1,222 @@
/*=========================================================================*\
* TCP object
*
* RCS ID: $Id$
\*=========================================================================*/
#include <string.h>
#include <lua.h>
#include <lauxlib.h>
#include "luasocket.h"
#include "aux.h"
#include "inet.h"
#include "tcp.h"
/*=========================================================================*\
* Internal function prototypes
\*=========================================================================*/
static int tcp_global_create(lua_State *L);
static int tcp_meth_connect(lua_State *L);
static int tcp_meth_bind(lua_State *L);
static int tcp_meth_send(lua_State *L);
static int tcp_meth_getsockname(lua_State *L);
static int tcp_meth_getpeername(lua_State *L);
static int tcp_meth_receive(lua_State *L);
static int tcp_meth_accept(lua_State *L);
static int tcp_meth_close(lua_State *L);
static int tcp_meth_timeout(lua_State *L);
/* tcp object methods */
static luaL_reg tcp[] = {
{"connect", tcp_meth_connect},
{"send", tcp_meth_send},
{"receive", tcp_meth_receive},
{"bind", tcp_meth_bind},
{"accept", tcp_meth_accept},
{"setpeername", tcp_meth_connect},
{"setsockname", tcp_meth_bind},
{"getpeername", tcp_meth_getpeername},
{"getsockname", tcp_meth_getsockname},
{"timeout", tcp_meth_timeout},
{"close", tcp_meth_close},
{NULL, NULL}
};
/* functions in library namespace */
static luaL_reg func[] = {
{"tcp", tcp_global_create},
{NULL, NULL}
};
/*-------------------------------------------------------------------------*\
* Initializes module
\*-------------------------------------------------------------------------*/
void tcp_open(lua_State *L)
{
/* create classes */
aux_newclass(L, "tcp{master}", tcp);
aux_newclass(L, "tcp{client}", tcp);
aux_newclass(L, "tcp{server}", tcp);
/* create class groups */
aux_add2group(L, "tcp{client}", "tcp{client, server}");
aux_add2group(L, "tcp{server}", "tcp{client, server}");
aux_add2group(L, "tcp{master}", "tcp{any}");
aux_add2group(L, "tcp{client}", "tcp{any}");
aux_add2group(L, "tcp{server}", "tcp{any}");
/* define library functions */
luaL_openlib(L, LUASOCKET_LIBNAME, func, 0);
lua_pop(L, 1);
}
/*=========================================================================*\
* Lua methods
\*=========================================================================*/
/*-------------------------------------------------------------------------*\
* Just call buffered IO methods
\*-------------------------------------------------------------------------*/
static int tcp_meth_send(lua_State *L)
{
p_tcp tcp = (p_tcp) aux_checkclass(L, "tcp{client}", 1);
return buf_meth_send(L, &tcp->buf);
}
static int tcp_meth_receive(lua_State *L)
{
p_tcp tcp = (p_tcp) aux_checkclass(L, "tcp{client}", 1);
return buf_meth_receive(L, &tcp->buf);
}
/*-------------------------------------------------------------------------*\
* Just call inet methods
\*-------------------------------------------------------------------------*/
static int tcp_meth_getpeername(lua_State *L)
{
p_tcp tcp = (p_tcp) aux_checkclass(L, "tcp{client}", 1);
return inet_meth_getpeername(L, &tcp->sock);
}
static int tcp_meth_getsockname(lua_State *L)
{
p_tcp tcp = (p_tcp) aux_checkgroup(L, "tcp{client, server}", 1);
return inet_meth_getsockname(L, &tcp->sock);
}
/*-------------------------------------------------------------------------*\
* Just call tm methods
\*-------------------------------------------------------------------------*/
static int tcp_meth_timeout(lua_State *L)
{
p_tcp tcp = (p_tcp) aux_checkgroup(L, "tcp{any}", 1);
return tm_meth_timeout(L, &tcp->tm);
}
/*-------------------------------------------------------------------------*\
* Closes socket used by object
\*-------------------------------------------------------------------------*/
static int tcp_meth_close(lua_State *L)
{
p_tcp tcp = (p_tcp) aux_checkgroup(L, "tcp{any}", 1);
sock_destroy(&tcp->sock);
return 0;
}
/*-------------------------------------------------------------------------*\
* Turns a master tcp object into a client object.
\*-------------------------------------------------------------------------*/
static int tcp_meth_connect(lua_State *L)
{
p_tcp tcp = (p_tcp) aux_checkclass(L, "tcp{master}", 1);
const char *address = luaL_checkstring(L, 2);
unsigned short port = (ushort) luaL_checknumber(L, 3);
const char *err = inet_tryconnect(&tcp->sock, address, port);
if (err) {
lua_pushnil(L);
lua_pushstring(L, err);
return 2;
}
/* turn master object into a client object */
aux_setclass(L, "tcp{client}", 1);
lua_pushnumber(L, 1);
return 1;
}
/*-------------------------------------------------------------------------*\
* Turns a master object into a server object
\*-------------------------------------------------------------------------*/
static int tcp_meth_bind(lua_State *L)
{
p_tcp tcp = (p_tcp) aux_checkclass(L, "tcp{master}", 1);
const char *address = luaL_checkstring(L, 2);
unsigned short port = (ushort) luaL_checknumber(L, 3);
int backlog = (int) luaL_optnumber(L, 4, 1);
const char *err = inet_trybind(&tcp->sock, address, port, backlog);
if (err) {
lua_pushnil(L);
lua_pushstring(L, err);
return 2;
}
/* turn master object into a server object */
aux_setclass(L, "tcp{server}", 1);
lua_pushnumber(L, 1);
return 1;
}
/*-------------------------------------------------------------------------*\
* Waits for and returns a client object attempting connection to the
* server object
\*-------------------------------------------------------------------------*/
static int tcp_meth_accept(lua_State *L)
{
struct sockaddr_in addr;
size_t addr_len = sizeof(addr);
p_tcp server = (p_tcp) aux_checkclass(L, "tcp{server}", 1);
p_tm tm = &server->tm;
p_tcp client = lua_newuserdata(L, sizeof(t_tcp));
tm_markstart(tm);
aux_setclass(L, "tcp{client}", -1);
for ( ;; ) {
sock_accept(&server->sock, &client->sock,
(SA *) &addr, &addr_len, tm_get(tm));
if (client->sock == SOCK_INVALID) {
if (tm_get(tm) == 0) {
lua_pushnil(L);
error_push(L, IO_TIMEOUT);
return 2;
}
} else break;
}
/* initialize remaining structure fields */
io_init(&client->io, (p_send) sock_send, (p_recv) sock_recv, &client->sock);
tm_init(&client->tm, -1, -1);
buf_init(&client->buf, &client->io, &client->tm);
return 1;
}
/*=========================================================================*\
* Library functions
\*=========================================================================*/
/*-------------------------------------------------------------------------*\
* Creates a master tcp object
\*-------------------------------------------------------------------------*/
int tcp_global_create(lua_State *L)
{
/* allocate tcp object */
p_tcp tcp = (p_tcp) lua_newuserdata(L, sizeof(t_tcp));
/* set its type as master object */
aux_setclass(L, "tcp{master}", -1);
/* try to allocate a system socket */
const char *err = inet_trycreate(&tcp->sock, SOCK_STREAM);
if (err) { /* get rid of object on stack and push error */
lua_pop(L, 1);
lua_pushnil(L);
lua_pushstring(L, err);
return 2;
}
/* initialize remaining structure fields */
io_init(&tcp->io, (p_send) sock_send, (p_recv) sock_recv, &tcp->sock);
tm_init(&tcp->tm, -1, -1);
buf_init(&tcp->buf, &tcp->io, &tcp->tm);
return 1;
}

20
src/tcp.h Normal file
View File

@ -0,0 +1,20 @@
#ifndef TCP_H
#define TCP_H
#include <lua.h>
#include "buf.h"
#include "tm.h"
#include "sock.h"
typedef struct t_tcp_ {
t_sock sock;
t_io io;
t_buf buf;
t_tm tm;
} t_tcp;
typedef t_tcp *p_tcp;
void tcp_open(lua_State *L);
#endif

View File

@ -1,18 +1,19 @@
/*=========================================================================*\ /*=========================================================================*\
* Timeout management functions * Timeout management functions
* Global Lua functions: * Global Lua functions:
* _sleep: (debug mode only) * _sleep
* _time: (debug mode only) * _time
* *
* RCS ID: $Id$ * RCS ID: $Id$
\*=========================================================================*/ \*=========================================================================*/
#include <stdio.h>
#include <lua.h> #include <lua.h>
#include <lauxlib.h> #include <lauxlib.h>
#include "lspriv.h" #include "luasocket.h"
#include "lstm.h" #include "aux.h"
#include "tm.h"
#include <stdio.h>
#ifdef WIN32 #ifdef WIN32
#include <windows.h> #include <windows.h>
@ -28,78 +29,69 @@
static int tm_lua_time(lua_State *L); static int tm_lua_time(lua_State *L);
static int tm_lua_sleep(lua_State *L); static int tm_lua_sleep(lua_State *L);
static luaL_reg func[] = {
{ "time", tm_lua_time },
{ "sleep", tm_lua_sleep },
{ NULL, NULL }
};
/*=========================================================================*\ /*=========================================================================*\
* Exported functions. * Exported functions.
\*=========================================================================*/ \*=========================================================================*/
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Sets timeout limits * Initialize structure
* Input
* tm: timeout control structure
* mode: block or return timeout
* value: timeout value in miliseconds
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
void tm_set(p_tm tm, int tm_block, int tm_return) void tm_init(p_tm tm, int block, int total)
{ {
tm->tm_block = tm_block; tm->block = block;
tm->tm_return = tm_return; tm->total = total;
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Returns timeout limits * Set and get timeout limits
* Input
* tm: timeout control structure
* mode: block or return timeout
* value: timeout value in miliseconds
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
void tm_get(p_tm tm, int *tm_block, int *tm_return) void tm_setblock(p_tm tm, int block)
{ { tm->block = block; }
if (tm_block) *tm_block = tm->tm_block; void tm_settotal(p_tm tm, int total)
if (tm_return) *tm_return = tm->tm_return; { tm->total = total; }
} int tm_getblock(p_tm tm)
{ return tm->block; }
int tm_gettotal(p_tm tm)
{ return tm->total; }
int tm_getstart(p_tm tm)
{ return tm->start; }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Determines how much time we have left for the current io operation * Determines how much time we have left for the current operation
* an IO write operation.
* Input * Input
* tm: timeout control structure * tm: timeout control structure
* Returns * Returns
* the number of ms left or -1 if there is no time limit * the number of ms left or -1 if there is no time limit
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
int tm_getremaining(p_tm tm) int tm_get(p_tm tm)
{ {
/* no timeout */ /* no timeout */
if (tm->tm_block < 0 && tm->tm_return < 0) if (tm->block < 0 && tm->total < 0)
return -1; return -1;
/* there is no block timeout, we use the return timeout */ /* there is no block timeout, we use the return timeout */
else if (tm->tm_block < 0) else if (tm->block < 0)
return MAX(tm->tm_return - tm_gettime() + tm->tm_start, 0); return MAX(tm->total - tm_gettime() + tm->start, 0);
/* there is no return timeout, we use the block timeout */ /* there is no return timeout, we use the block timeout */
else if (tm->tm_return < 0) else if (tm->total < 0)
return tm->tm_block; return tm->block;
/* both timeouts are specified */ /* both timeouts are specified */
else return MIN(tm->tm_block, else return MIN(tm->block,
MAX(tm->tm_return - tm_gettime() + tm->tm_start, 0)); MAX(tm->total - tm_gettime() + tm->start, 0));
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Marks the operation start time in sock structure * Marks the operation start time in structure
* Input * Input
* tm: timeout control structure * tm: timeout control structure
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
void tm_markstart(p_tm tm) void tm_markstart(p_tm tm)
{ {
tm->tm_start = tm_gettime(); tm->start = tm_gettime();
tm->tm_end = tm->tm_start;
}
/*-------------------------------------------------------------------------*\
* Returns the length of the operation in ms
* Input
* tm: timeout control structure
\*-------------------------------------------------------------------------*/
int tm_getelapsed(p_tm tm)
{
return tm->tm_end - tm->tm_start;
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
@ -125,11 +117,31 @@ int tm_gettime(void)
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
void tm_open(lua_State *L) void tm_open(lua_State *L)
{ {
(void) L; luaL_openlib(L, LUASOCKET_LIBNAME, func, 0);
lua_pushcfunction(L, tm_lua_time); }
priv_newglobal(L, "_time");
lua_pushcfunction(L, tm_lua_sleep); /*-------------------------------------------------------------------------*\
priv_newglobal(L, "_sleep"); * Sets timeout values for IO operations
* Lua Input: base, time [, mode]
* time: time out value in seconds
* mode: "b" for block timeout, "t" for total timeout. (default: b)
\*-------------------------------------------------------------------------*/
int tm_meth_timeout(lua_State *L, p_tm tm)
{
int ms = lua_isnil(L, 2) ? -1 : (int) (luaL_checknumber(L, 2)*1000.0);
const char *mode = luaL_optstring(L, 3, "b");
switch (*mode) {
case 'b':
tm_setblock(tm, ms);
break;
case 'r': case 't':
tm_settotal(tm, ms);
break;
default:
luaL_argcheck(L, 0, 3, "invalid timeout mode");
break;
}
return 0;
} }
/*=========================================================================*\ /*=========================================================================*\

View File

@ -3,23 +3,29 @@
* *
* RCS ID: $Id$ * RCS ID: $Id$
\*=========================================================================*/ \*=========================================================================*/
#ifndef _TM_H #ifndef TM_H
#define _TM_H #define TM_H
typedef struct t_tm_tag { #include <lua.h>
int tm_return;
int tm_block; /* timeout control structure */
int tm_start; typedef struct t_tm_ {
int tm_end; int total; /* total number of miliseconds for operation */
int block; /* maximum time for blocking calls */
int start; /* time of start of operation */
} t_tm; } t_tm;
typedef t_tm *p_tm; typedef t_tm *p_tm;
void tm_set(p_tm tm, int tm_block, int tm_return);
int tm_getremaining(p_tm tm);
int tm_getelapsed(p_tm tm);
int tm_gettime(void);
void tm_get(p_tm tm, int *tm_block, int *tm_return);
void tm_markstart(p_tm tm);
void tm_open(lua_State *L); void tm_open(lua_State *L);
void tm_init(p_tm tm, int block, int total);
void tm_setblock(p_tm tm, int block);
void tm_settotal(p_tm tm, int total);
int tm_getblock(p_tm tm);
int tm_gettotal(p_tm tm);
void tm_markstart(p_tm tm);
int tm_getstart(p_tm tm);
int tm_get(p_tm tm);
int tm_gettime(void);
int tm_meth_timeout(lua_State *L, p_tm tm);
#endif #endif

438
src/udp.c
View File

@ -1,15 +1,5 @@
/*=========================================================================*\ /*=========================================================================*\
* UDP class: inherits from Socked and Internet domain classes and provides * UDP object
* all the functionality for UDP objects.
* Lua methods:
* send: using compat module
* sendto: using compat module
* receive: using compat module
* receivefrom: using compat module
* setpeername: using internet module
* setsockname: using internet module
* Global Lua functions:
* udp: creates the udp object
* *
* RCS ID: $Id$ * RCS ID: $Id$
\*=========================================================================*/ \*=========================================================================*/
@ -18,282 +8,256 @@
#include <lua.h> #include <lua.h>
#include <lauxlib.h> #include <lauxlib.h>
#include "lsinet.h" #include "luasocket.h"
#include "lsudp.h"
#include "lscompat.h" #include "aux.h"
#include "lsselect.h" #include "inet.h"
#include "udp.h"
/*=========================================================================*\ /*=========================================================================*\
* Internal function prototypes. * Internal function prototypes
\*=========================================================================*/ \*=========================================================================*/
static int udp_lua_send(lua_State *L); static int udp_global_create(lua_State *L);
static int udp_lua_sendto(lua_State *L); static int udp_meth_send(lua_State *L);
static int udp_lua_receive(lua_State *L); static int udp_meth_sendto(lua_State *L);
static int udp_lua_receivefrom(lua_State *L); static int udp_meth_receive(lua_State *L);
static int udp_lua_setpeername(lua_State *L); static int udp_meth_receivefrom(lua_State *L);
static int udp_lua_setsockname(lua_State *L); static int udp_meth_getsockname(lua_State *L);
static int udp_meth_getpeername(lua_State *L);
static int udp_meth_setsockname(lua_State *L);
static int udp_meth_setpeername(lua_State *L);
static int udp_meth_close(lua_State *L);
static int udp_meth_timeout(lua_State *L);
static int udp_global_udp(lua_State *L); /* udp object methods */
static luaL_reg udp[] = {
static struct luaL_reg funcs[] = { {"setpeername", udp_meth_setpeername},
{"send", udp_lua_send}, {"setsockname", udp_meth_setsockname},
{"sendto", udp_lua_sendto}, {"getsockname", udp_meth_getsockname},
{"receive", udp_lua_receive}, {"getpeername", udp_meth_getpeername},
{"receivefrom", udp_lua_receivefrom}, {"send", udp_meth_send},
{"setpeername", udp_lua_setpeername}, {"sendto", udp_meth_sendto},
{"setsockname", udp_lua_setsockname}, {"receive", udp_meth_receive},
{"receivefrom", udp_meth_receivefrom},
{"timeout", udp_meth_timeout},
{"close", udp_meth_close},
{NULL, NULL}
};
/* functions in library namespace */
static luaL_reg func[] = {
{"udp", udp_global_create},
{NULL, NULL}
}; };
/*=========================================================================*\
* Exported functions
\*=========================================================================*/
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Initializes module * Initializes module
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
void udp_open(lua_State *L) void udp_open(lua_State *L)
{ {
unsigned int i; /* create classes */
priv_newclass(L, UDP_CLASS); aux_newclass(L, "udp{connected}", udp);
udp_inherit(L, UDP_CLASS); aux_newclass(L, "udp{unconnected}", udp);
/* declare global functions */ /* create class groups */
lua_pushcfunction(L, udp_global_udp); aux_add2group(L, "udp{connected}", "udp{any}");
priv_newglobal(L, "udp"); aux_add2group(L, "udp{unconnected}", "udp{any}");
for (i = 0; i < sizeof(funcs)/sizeof(funcs[0]); i++) /* define library functions */
priv_newglobalmethod(L, funcs[i].name); luaL_openlib(L, LUASOCKET_LIBNAME, func, 0);
/* make class selectable */ lua_pop(L, 1);
select_addclass(L, UDP_CLASS);
}
/*-------------------------------------------------------------------------*\
* Hook object methods to methods table.
\*-------------------------------------------------------------------------*/
void udp_inherit(lua_State *L, cchar *lsclass)
{
unsigned int i;
inet_inherit(L, lsclass);
for (i = 0; i < sizeof(funcs)/sizeof(funcs[0]); i++) {
lua_pushcfunction(L, funcs[i].func);
priv_setmethod(L, lsclass, funcs[i].name);
}
}
/*-------------------------------------------------------------------------*\
* Initializes socket structure
\*-------------------------------------------------------------------------*/
void udp_construct(lua_State *L, p_udp udp)
{
inet_construct(L, (p_inet) udp);
udp->udp_connected = 0;
}
/*-------------------------------------------------------------------------*\
* Creates a socket structure and initializes it. A socket object is
* left in the Lua stack.
* Returns
* pointer to allocated structure
\*-------------------------------------------------------------------------*/
p_udp udp_push(lua_State *L)
{
p_udp udp = (p_udp) lua_newuserdata(L, sizeof(t_udp));
priv_setclass(L, UDP_CLASS);
udp_construct(L, udp);
return udp;
} }
/*=========================================================================*\ /*=========================================================================*\
* Socket table constructors * Lua methods
\*=========================================================================*/ \*=========================================================================*/
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Creates a udp socket object and returns it to the Lua script. * Send data through connected udp socket
* Lua Input: [options]
* options: socket options table
* Lua Returns
* On success: udp socket
* On error: nil, followed by an error message
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
static int udp_global_udp(lua_State *L) static int udp_meth_send(lua_State *L)
{ {
int oldtop = lua_gettop(L); p_udp udp = (p_udp) aux_checkclass(L, "udp{connected}", 1);
p_udp udp = udp_push(L); p_tm tm = &udp->tm;
cchar *err = inet_trysocket((p_inet) udp, SOCK_DGRAM); size_t count, sent = 0;
if (err) {
lua_pushnil(L);
lua_pushstring(L, err);
return 2;
}
if (oldtop < 1) return 1;
err = compat_trysetoptions(L, udp->fd);
if (err) {
lua_pushnil(L);
lua_pushstring(L, err);
return 2;
}
return 1;
}
/*=========================================================================*\
* Socket table methods
\*=========================================================================*/
/*-------------------------------------------------------------------------*\
* Receives data from a UDP socket
* Lua Input: sock [, wanted]
* sock: client socket created by the connect function
* wanted: the number of bytes expected (default: LUASOCKET_UDPBUFFERSIZE)
* Lua Returns
* On success: datagram received
* On error: nil, followed by an error message
\*-------------------------------------------------------------------------*/
static int udp_lua_receive(lua_State *L)
{
p_udp udp = (p_udp) lua_touserdata(L, 1);
char buffer[UDP_DATAGRAMSIZE];
size_t got, wanted = (size_t) luaL_optnumber(L, 2, sizeof(buffer));
int err; int err;
p_tm tm = &udp->base_tm; const char *data = luaL_checklstring(L, 2, &count);
wanted = MIN(wanted, sizeof(buffer));
tm_markstart(tm); tm_markstart(tm);
err = compat_recv(udp->fd, buffer, wanted, &got, tm_getremaining(tm)); err = sock_send(&udp->sock, data, count, &sent, tm_get(tm));
if (err == PRIV_CLOSED) err = PRIV_REFUSED; if (err == IO_DONE) lua_pushnumber(L, sent);
if (err != PRIV_DONE) lua_pushnil(L); else lua_pushnil(L);
else lua_pushlstring(L, buffer, got); error_push(L, err);
priv_pusherror(L, err);
return 2; return 2;
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Receives a datagram from a UDP socket * Send data through unconnected udp socket
* Lua Input: sock [, wanted]
* sock: client socket created by the connect function
* wanted: the number of bytes expected (default: LUASOCKET_UDPBUFFERSIZE)
* Lua Returns
* On success: datagram received, ip and port of sender
* On error: nil, followed by an error message
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
static int udp_lua_receivefrom(lua_State *L) static int udp_meth_sendto(lua_State *L)
{ {
p_udp udp = (p_udp) lua_touserdata(L, 1); p_udp udp = (p_udp) aux_checkclass(L, "udp{unconnected}", 1);
p_tm tm = &udp->base_tm; size_t count, sent = 0;
struct sockaddr_in peer; const char *data = luaL_checklstring(L, 2, &count);
size_t peer_len = sizeof(peer); const char *ip = luaL_checkstring(L, 3);
char buffer[UDP_DATAGRAMSIZE]; ushort port = (ushort) luaL_checknumber(L, 4);
size_t wanted = (size_t) luaL_optnumber(L, 2, sizeof(buffer)); p_tm tm = &udp->tm;
size_t got; struct sockaddr_in addr;
int err; int err;
if (udp->udp_connected) luaL_error(L, "receivefrom on connected socket"); memset(&addr, 0, sizeof(addr));
if (!inet_aton(ip, &addr.sin_addr))
luaL_argerror(L, 3, "invalid ip address");
addr.sin_family = AF_INET;
addr.sin_port = htons(port);
tm_markstart(tm); tm_markstart(tm);
wanted = MIN(wanted, sizeof(buffer)); err = sock_sendto(&udp->sock, data, count, &sent,
err = compat_recvfrom(udp->fd, buffer, wanted, &got, tm_getremaining(tm), (SA *) &addr, sizeof(addr), tm_get(tm));
(SA *) &peer, &peer_len); if (err == IO_DONE) lua_pushnumber(L, sent);
if (err == PRIV_CLOSED) err = PRIV_REFUSED; else lua_pushnil(L);
if (err == PRIV_DONE) { error_push(L, err == IO_CLOSED ? IO_REFUSED : err);
return 2;
}
/*-------------------------------------------------------------------------*\
* Receives data from a UDP socket
\*-------------------------------------------------------------------------*/
static int udp_meth_receive(lua_State *L)
{
p_udp udp = (p_udp) aux_checkgroup(L, "udp{any}", 1);
char buffer[UDP_DATAGRAMSIZE];
size_t got, count = (size_t) luaL_optnumber(L, 2, sizeof(buffer));
int err;
p_tm tm = &udp->tm;
count = MIN(count, sizeof(buffer));
tm_markstart(tm);
err = sock_recv(&udp->sock, buffer, count, &got, tm_get(tm));
if (err == IO_DONE) lua_pushlstring(L, buffer, got);
else lua_pushnil(L);
error_push(L, err);
return 2;
}
/*-------------------------------------------------------------------------*\
* Receives data and sender from a UDP socket
\*-------------------------------------------------------------------------*/
static int udp_meth_receivefrom(lua_State *L)
{
p_udp udp = (p_udp) aux_checkclass(L, "udp{unconnected}", 1);
struct sockaddr_in addr;
size_t addr_len = sizeof(addr);
char buffer[UDP_DATAGRAMSIZE];
size_t got, count = (size_t) luaL_optnumber(L, 2, sizeof(buffer));
int err;
p_tm tm = &udp->tm;
tm_markstart(tm);
count = MIN(count, sizeof(buffer));
err = sock_recvfrom(&udp->sock, buffer, count, &got,
(SA *) &addr, &addr_len, tm_get(tm));
if (err == IO_DONE) {
lua_pushlstring(L, buffer, got); lua_pushlstring(L, buffer, got);
lua_pushstring(L, inet_ntoa(peer.sin_addr)); lua_pushstring(L, inet_ntoa(addr.sin_addr));
lua_pushnumber(L, ntohs(peer.sin_port)); lua_pushnumber(L, ntohs(addr.sin_port));
return 3; return 3;
} else { } else {
lua_pushnil(L); lua_pushnil(L);
priv_pusherror(L, err); error_push(L, err);
return 2; return 2;
} }
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Send data through a connected UDP socket * Just call inet methods
* Lua Input: sock, data
* sock: udp socket
* data: data to be sent
* Lua Returns
* On success: nil, followed by the total number of bytes sent
* On error: error message
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
static int udp_lua_send(lua_State *L) static int udp_meth_getpeername(lua_State *L)
{ {
p_udp udp = (p_udp) lua_touserdata(L, 1); p_udp udp = (p_udp) aux_checkclass(L, "udp{connected}", 1);
p_tm tm = &udp->base_tm; return inet_meth_getpeername(L, &udp->sock);
size_t wanted, sent = 0; }
int err;
cchar *data = luaL_checklstring(L, 2, &wanted); static int udp_meth_getsockname(lua_State *L)
if (!udp->udp_connected) luaL_error(L, "send on unconnected socket"); {
tm_markstart(tm); p_udp udp = (p_udp) aux_checkgroup(L, "udp{any}", 1);
err = compat_send(udp->fd, data, wanted, &sent, tm_getremaining(tm)); return inet_meth_getsockname(L, &udp->sock);
priv_pusherror(L, err == PRIV_CLOSED ? PRIV_REFUSED : err);
lua_pushnumber(L, sent);
return 2;
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Send data through a unconnected UDP socket * Just call tm methods
* Lua Input: sock, data, ip, port
* sock: udp socket
* data: data to be sent
* ip: ip address of target
* port: port in target
* Lua Returns
* On success: nil, followed by the total number of bytes sent
* On error: error message
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
static int udp_lua_sendto(lua_State *L) static int udp_meth_timeout(lua_State *L)
{ {
p_udp udp = (p_udp) lua_touserdata(L, 1); p_udp udp = (p_udp) aux_checkgroup(L, "udp{any}", 1);
size_t wanted, sent = 0; return tm_meth_timeout(L, &udp->tm);
cchar *data = luaL_checklstring(L, 2, &wanted);
cchar *ip = luaL_checkstring(L, 3);
ushort port = (ushort) luaL_checknumber(L, 4);
p_tm tm = &udp->base_tm;
struct sockaddr_in peer;
int err;
if (udp->udp_connected) luaL_error(L, "sendto on connected socket");
memset(&peer, 0, sizeof(peer));
if (!inet_aton(ip, &peer.sin_addr)) luaL_error(L, "invalid ip address");
peer.sin_family = AF_INET;
peer.sin_port = htons(port);
tm_markstart(tm);
err = compat_sendto(udp->fd, data, wanted, &sent, tm_getremaining(tm),
(SA *) &peer, sizeof(peer));
priv_pusherror(L, err == PRIV_CLOSED ? PRIV_REFUSED : err);
lua_pushnumber(L, sent);
return 2;
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Associates a local address to an UDP socket * Turns a master udp object into a client object.
* Lua Input: address, port
* address: host name or ip address to bind to
* port: port to bind to
* Lua Returns
* On success: nil
* On error: error message
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
static int udp_lua_setsockname(lua_State * L) static int udp_meth_setpeername(lua_State *L)
{ {
p_udp udp = (p_udp) lua_touserdata(L, 1); p_udp udp = (p_udp) aux_checkclass(L, "udp{unconnected}", 1);
cchar *address = luaL_checkstring(L, 2); const char *address = luaL_checkstring(L, 2);
ushort port = (ushort) luaL_checknumber(L, 3); int connecting = strcmp(address, "*");
cchar *err = inet_trybind((p_inet) udp, address, port); unsigned short port = connecting ?
if (err) lua_pushstring(L, err); (ushort) luaL_checknumber(L, 3) : (ushort) luaL_optnumber(L, 3, 0);
else lua_pushnil(L); const char *err = inet_tryconnect(&udp->sock, address, port);
return 1; if (err) {
}
/*-------------------------------------------------------------------------*\
* Sets a peer for a UDP socket
* Lua Input: address, port
* address: remote host name
* port: remote host port
* Lua Returns
* On success: nil
* On error: error message
\*-------------------------------------------------------------------------*/
static int udp_lua_setpeername(lua_State *L)
{
p_udp udp = (p_udp) lua_touserdata(L, 1);
cchar *address = luaL_checkstring(L, 2);
ushort port = (ushort) luaL_checknumber(L, 3);
cchar *err = inet_tryconnect((p_inet) udp, address, port);
if (!err) {
udp->udp_connected = 1;
lua_pushnil(L); lua_pushnil(L);
} else lua_pushstring(L, err); lua_pushstring(L, err);
return 2;
}
/* change class to connected or unconnected depending on address */
if (connecting) aux_setclass(L, "udp{connected}", 1);
else aux_setclass(L, "udp{unconnected}", 1);
lua_pushnumber(L, 1);
return 1; return 1;
} }
/*-------------------------------------------------------------------------*\
* Closes socket used by object
\*-------------------------------------------------------------------------*/
static int udp_meth_close(lua_State *L)
{
p_udp udp = (p_udp) aux_checkgroup(L, "udp{any}", 1);
sock_destroy(&udp->sock);
return 0;
}
/*-------------------------------------------------------------------------*\
* Turns a master object into a server object
\*-------------------------------------------------------------------------*/
static int udp_meth_setsockname(lua_State *L)
{
p_udp udp = (p_udp) aux_checkclass(L, "udp{master}", 1);
const char *address = luaL_checkstring(L, 2);
unsigned short port = (ushort) luaL_checknumber(L, 3);
const char *err = inet_trybind(&udp->sock, address, port, -1);
if (err) {
lua_pushnil(L);
lua_pushstring(L, err);
return 2;
}
lua_pushnumber(L, 1);
return 1;
}
/*=========================================================================*\
* Library functions
\*=========================================================================*/
/*-------------------------------------------------------------------------*\
* Creates a master udp object
\*-------------------------------------------------------------------------*/
int udp_global_create(lua_State *L)
{
/* allocate udp object */
p_udp udp = (p_udp) lua_newuserdata(L, sizeof(t_udp));
/* set its type as master object */
aux_setclass(L, "udp{unconnected}", -1);
/* try to allocate a system socket */
const char *err = inet_trycreate(&udp->sock, SOCK_DGRAM);
if (err) {
/* get rid of object on stack and push error */
lua_pop(L, 1);
lua_pushnil(L);
lua_pushstring(L, err);
return 2;
}
/* initialize timeout management */
tm_init(&udp->tm, -1, -1);
return 1;
}

View File

@ -1,30 +1,19 @@
/*=========================================================================*\ #ifndef UDP_H
* UDP class: inherits from Socked and Internet domain classes and provides #define UDP_H
* all the functionality for UDP objects.
*
* RCS ID: $Id$
\*=========================================================================*/
#ifndef UDP_H_
#define UDP_H_
#include "lsinet.h" #include <lua.h>
#define UDP_CLASS "luasocket(UDP socket)" #include "tm.h"
#include "sock.h"
#define UDP_DATAGRAMSIZE 576 #define UDP_DATAGRAMSIZE 576
#define UDP_FIELDS \ typedef struct t_udp_ {
INET_FIELDS; \ t_sock sock;
int udp_connected t_tm tm;
typedef struct t_udp_tag {
UDP_FIELDS;
} t_udp; } t_udp;
typedef t_udp *p_udp; typedef t_udp *p_udp;
void udp_inherit(lua_State *L, cchar *lsclass);
void udp_construct(lua_State *L, p_udp udp);
void udp_open(lua_State *L); void udp_open(lua_State *L);
p_udp udp_push(lua_State *L);
#endif #endif

View File

@ -1,5 +1,5 @@
/*=========================================================================*\ /*=========================================================================*\
* Network compatibilization module: Unix version * Socket compatibilization module for Unix
* *
* RCS ID: $Id$ * RCS ID: $Id$
\*=========================================================================*/ \*=========================================================================*/
@ -7,20 +7,20 @@
#include <lauxlib.h> #include <lauxlib.h>
#include <string.h> #include <string.h>
#include "lscompat.h" #include "sock.h"
/*=========================================================================*\ /*=========================================================================*\
* Internal function prototypes * Internal function prototypes
\*=========================================================================*/ \*=========================================================================*/
static cchar *try_setoption(lua_State *L, COMPAT_FD sock); static const char *try_setoption(lua_State *L, p_sock ps);
static cchar *try_setbooloption(lua_State *L, COMPAT_FD sock, int name); static const char *try_setbooloption(lua_State *L, p_sock ps, int name);
/*=========================================================================*\ /*=========================================================================*\
* Exported functions. * Exported functions.
\*=========================================================================*/ \*=========================================================================*/
int compat_open(lua_State *L) int sock_open(lua_State *L)
{ {
/* Instals a handler to ignore sigpipe. */ /* instals a handler to ignore sigpipe. */
struct sigaction new; struct sigaction new;
memset(&new, 0, sizeof(new)); memset(&new, 0, sizeof(new));
new.sa_handler = SIG_IGN; new.sa_handler = SIG_IGN;
@ -28,143 +28,178 @@ int compat_open(lua_State *L)
return 1; return 1;
} }
COMPAT_FD compat_accept(COMPAT_FD s, struct sockaddr *addr, void sock_destroy(p_sock ps)
size_t *len, int deadline)
{ {
struct timeval tv; close(*ps);
fd_set fds;
tv.tv_sec = deadline / 1000;
tv.tv_usec = (deadline % 1000) * 1000;
FD_ZERO(&fds);
FD_SET(s, &fds);
select(s+1, &fds, NULL, NULL, deadline >= 0 ? &tv : NULL);
return accept(s, addr, len);
} }
int compat_send(COMPAT_FD c, cchar *data, size_t count, size_t *sent, const char *sock_create(p_sock ps, int domain, int type, int protocol)
int deadline)
{ {
t_sock sock = socket(domain, type, protocol);
if (sock == SOCK_INVALID) return sock_createstrerror();
*ps = sock;
sock_setnonblocking(ps);
sock_setreuseaddr(ps);
return NULL;
}
const char *sock_connect(p_sock ps, SA *addr, size_t addr_len)
{
if (connect(*ps, addr, addr_len) < 0) return sock_connectstrerror();
else return NULL;
}
const char *sock_bind(p_sock ps, SA *addr, size_t addr_len)
{
if (bind(*ps, addr, addr_len) < 0) return sock_bindstrerror();
else return NULL;
}
void sock_listen(p_sock ps, int backlog)
{
listen(*ps, backlog);
}
void sock_accept(p_sock ps, p_sock pa, SA *addr, size_t *addr_len, int timeout)
{
t_sock sock = *ps;
struct timeval tv;
fd_set fds;
tv.tv_sec = timeout / 1000;
tv.tv_usec = (timeout % 1000) * 1000;
FD_ZERO(&fds);
FD_SET(sock, &fds);
select(sock+1, &fds, NULL, NULL, timeout >= 0 ? &tv : NULL);
*pa = accept(sock, addr, addr_len);
}
int sock_send(p_sock ps, const char *data, size_t count, size_t *sent,
int timeout)
{
t_sock sock = *ps;
struct timeval tv; struct timeval tv;
fd_set fds; fd_set fds;
ssize_t put = 0; ssize_t put = 0;
int err; int err;
int ret; int ret;
tv.tv_sec = deadline / 1000; tv.tv_sec = timeout / 1000;
tv.tv_usec = (deadline % 1000) * 1000; tv.tv_usec = (timeout % 1000) * 1000;
FD_ZERO(&fds); FD_ZERO(&fds);
FD_SET(c, &fds); FD_SET(sock, &fds);
ret = select(c+1, NULL, &fds, NULL, deadline >= 0 ? &tv : NULL); ret = select(sock+1, NULL, &fds, NULL, timeout >= 0 ? &tv : NULL);
if (ret > 0) { if (ret > 0) {
put = write(c, data, count); put = write(sock, data, count);
if (put <= 0) { if (put <= 0) {
err = PRIV_CLOSED; err = IO_CLOSED;
#ifdef __CYGWIN__ #ifdef __CYGWIN__
/* this is for CYGWIN, which is like Unix but has Win32 bugs */ /* this is for CYGWIN, which is like Unix but has Win32 bugs */
if (errno == EWOULDBLOCK) err = PRIV_DONE; if (errno == EWOULDBLOCK) err = IO_DONE;
#endif #endif
*sent = 0; *sent = 0;
} else { } else {
*sent = put; *sent = put;
err = PRIV_DONE; err = IO_DONE;
} }
return err; return err;
} else { } else {
*sent = 0; *sent = 0;
return PRIV_TIMEOUT; return IO_TIMEOUT;
} }
} }
int compat_sendto(COMPAT_FD c, cchar *data, size_t count, size_t *sent, int sock_sendto(p_sock ps, const char *data, size_t count, size_t *sent,
int deadline, SA *addr, size_t len) SA *addr, size_t addr_len, int timeout)
{ {
t_sock sock = *ps;
struct timeval tv; struct timeval tv;
fd_set fds; fd_set fds;
ssize_t put = 0; ssize_t put = 0;
int err; int err;
int ret; int ret;
tv.tv_sec = deadline / 1000; tv.tv_sec = timeout / 1000;
tv.tv_usec = (deadline % 1000) * 1000; tv.tv_usec = (timeout % 1000) * 1000;
FD_ZERO(&fds); FD_ZERO(&fds);
FD_SET(c, &fds); FD_SET(sock, &fds);
ret = select(c+1, NULL, &fds, NULL, deadline >= 0 ? &tv : NULL); ret = select(sock+1, NULL, &fds, NULL, timeout >= 0 ? &tv : NULL);
if (ret > 0) { if (ret > 0) {
put = sendto(c, data, count, 0, addr, len); put = sendto(sock, data, count, 0, addr, addr_len);
if (put <= 0) { if (put <= 0) {
err = PRIV_CLOSED; err = IO_CLOSED;
#ifdef __CYGWIN__ #ifdef __CYGWIN__
/* this is for CYGWIN, which is like Unix but has Win32 bugs */ /* this is for CYGWIN, which is like Unix but has Win32 bugs */
if (sent < 0 && errno == EWOULDBLOCK) err = PRIV_DONE; if (sent < 0 && errno == EWOULDBLOCK) err = IO_DONE;
#endif #endif
*sent = 0; *sent = 0;
} else { } else {
*sent = put; *sent = put;
err = PRIV_DONE; err = IO_DONE;
} }
return err; return err;
} else { } else {
*sent = 0; *sent = 0;
return PRIV_TIMEOUT; return IO_TIMEOUT;
} }
} }
int compat_recv(COMPAT_FD c, char *data, size_t count, size_t *got, int sock_recv(p_sock ps, char *data, size_t count, size_t *got, int timeout)
int deadline)
{ {
t_sock sock = *ps;
struct timeval tv; struct timeval tv;
fd_set fds; fd_set fds;
int ret; int ret;
ssize_t taken = 0; ssize_t taken = 0;
tv.tv_sec = deadline / 1000; tv.tv_sec = timeout / 1000;
tv.tv_usec = (deadline % 1000) * 1000; tv.tv_usec = (timeout % 1000) * 1000;
FD_ZERO(&fds); FD_ZERO(&fds);
FD_SET(c, &fds); FD_SET(sock, &fds);
ret = select(c+1, &fds, NULL, NULL, deadline >= 0 ? &tv : NULL); ret = select(sock+1, &fds, NULL, NULL, timeout >= 0 ? &tv : NULL);
if (ret > 0) { if (ret > 0) {
taken = read(c, data, count); taken = read(sock, data, count);
if (taken <= 0) { if (taken <= 0) {
*got = 0; *got = 0;
return PRIV_CLOSED; return IO_CLOSED;
} else { } else {
*got = taken; *got = taken;
return PRIV_DONE; return IO_DONE;
} }
} else { } else {
*got = 0; *got = 0;
return PRIV_TIMEOUT; return IO_TIMEOUT;
} }
} }
int compat_recvfrom(COMPAT_FD c, char *data, size_t count, size_t *got, int sock_recvfrom(p_sock ps, char *data, size_t count, size_t *got,
int deadline, SA *addr, size_t *len) SA *addr, size_t *addr_len, int timeout)
{ {
t_sock sock = *ps;
struct timeval tv; struct timeval tv;
fd_set fds; fd_set fds;
int ret; int ret;
ssize_t taken = 0; ssize_t taken = 0;
tv.tv_sec = deadline / 1000; tv.tv_sec = timeout / 1000;
tv.tv_usec = (deadline % 1000) * 1000; tv.tv_usec = (timeout % 1000) * 1000;
FD_ZERO(&fds); FD_ZERO(&fds);
FD_SET(c, &fds); FD_SET(sock, &fds);
ret = select(c+1, &fds, NULL, NULL, deadline >= 0 ? &tv : NULL); ret = select(sock+1, &fds, NULL, NULL, timeout >= 0 ? &tv : NULL);
if (ret > 0) { if (ret > 0) {
taken = recvfrom(c, data, count, 0, addr, len); taken = recvfrom(sock, data, count, 0, addr, addr_len);
if (taken <= 0) { if (taken <= 0) {
*got = 0; *got = 0;
return PRIV_CLOSED; return IO_CLOSED;
} else { } else {
*got = taken; *got = taken;
return PRIV_DONE; return IO_DONE;
} }
} else { } else {
*got = 0; *got = 0;
return PRIV_TIMEOUT; return IO_TIMEOUT;
} }
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Returns a string describing the last host manipulation error. * Returns a string describing the last host manipulation error.
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
const char *compat_hoststrerror(void) const char *sock_hoststrerror(void)
{ {
switch (h_errno) { switch (h_errno) {
case HOST_NOT_FOUND: return "host not found"; case HOST_NOT_FOUND: return "host not found";
@ -178,7 +213,7 @@ const char *compat_hoststrerror(void)
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Returns a string describing the last socket manipulation error. * Returns a string describing the last socket manipulation error.
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
const char *compat_socketstrerror(void) const char *sock_createstrerror(void)
{ {
switch (errno) { switch (errno) {
case EACCES: return "access denied"; case EACCES: return "access denied";
@ -192,7 +227,7 @@ const char *compat_socketstrerror(void)
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Returns a string describing the last bind command error. * Returns a string describing the last bind command error.
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
const char *compat_bindstrerror(void) const char *sock_bindstrerror(void)
{ {
switch (errno) { switch (errno) {
case EBADF: return "invalid descriptor"; case EBADF: return "invalid descriptor";
@ -209,7 +244,7 @@ const char *compat_bindstrerror(void)
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Returns a string describing the last connect error. * Returns a string describing the last connect error.
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
const char *compat_connectstrerror(void) const char *sock_connectstrerror(void)
{ {
switch (errno) { switch (errno) {
case EBADF: return "invalid descriptor"; case EBADF: return "invalid descriptor";
@ -229,40 +264,30 @@ const char *compat_connectstrerror(void)
* Input * Input
* sock: socket descriptor * sock: socket descriptor
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
void compat_setreuseaddr(COMPAT_FD sock) void sock_setreuseaddr(p_sock ps)
{ {
int val = 1; int val = 1;
setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (char *)&val, sizeof(val)); setsockopt(*ps, SOL_SOCKET, SO_REUSEADDR, (char *)&val, sizeof(val));
}
COMPAT_FD compat_socket(int domain, int type, int protocol)
{
COMPAT_FD sock = socket(domain, type, protocol);
if (sock != COMPAT_INVALIDFD) {
compat_setnonblocking(sock);
compat_setreuseaddr(sock);
}
return sock;
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Put socket into blocking mode. * Put socket into blocking mode.
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
void compat_setblocking(COMPAT_FD sock) void sock_setblocking(p_sock ps)
{ {
int flags = fcntl(sock, F_GETFL, 0); int flags = fcntl(*ps, F_GETFL, 0);
flags &= (~(O_NONBLOCK)); flags &= (~(O_NONBLOCK));
fcntl(sock, F_SETFL, flags); fcntl(*ps, F_SETFL, flags);
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Put socket into non-blocking mode. * Put socket into non-blocking mode.
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
void compat_setnonblocking(COMPAT_FD sock) void sock_setnonblocking(p_sock ps)
{ {
int flags = fcntl(sock, F_GETFL, 0); int flags = fcntl(*ps, F_GETFL, 0);
flags |= O_NONBLOCK; flags |= O_NONBLOCK;
fcntl(sock, F_SETFL, flags); fcntl(*ps, F_SETFL, flags);
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
@ -273,54 +298,50 @@ void compat_setnonblocking(COMPAT_FD sock)
* Returns * Returns
* NULL if successfull, error message on error * NULL if successfull, error message on error
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
cchar *compat_trysetoptions(lua_State *L, COMPAT_FD sock) const char *sock_trysetoptions(lua_State *L, p_sock ps)
{ {
if (!lua_istable(L, 1)) luaL_argerror(L, 1, "invalid options table"); if (!lua_istable(L, 1)) luaL_argerror(L, 1, "invalid options table");
lua_pushnil(L); lua_pushnil(L);
while (lua_next(L, 1)) { while (lua_next(L, 1)) {
cchar *err = try_setoption(L, sock); const char *err = try_setoption(L, ps);
lua_pop(L, 1); lua_pop(L, 1);
if (err) return err; if (err) return err;
} }
return NULL; return NULL;
} }
/*=========================================================================*\
* Internal functions.
\*=========================================================================*/
static cchar *try_setbooloption(lua_State *L, COMPAT_FD sock, int name)
{
int bool, res;
if (!lua_isnumber(L, -1)) luaL_error(L, "invalid option value");
bool = (int) lua_tonumber(L, -1);
res = setsockopt(sock, SOL_SOCKET, name, (char *) &bool, sizeof(bool));
if (res < 0) return "error setting option";
else return NULL;
}
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
* Set socket options from a table on top of Lua stack. * Set socket options from a table on top of Lua stack.
* Supports SO_KEEPALIVE, SO_DONTROUTE, SO_BROADCAST, and SO_LINGER options. * Supports SO_KEEPALIVE, SO_DONTROUTE, and SO_BROADCAST options.
* Input * Input
* L: Lua state to use * sock: socket
* sock: socket descriptor
* Returns * Returns
* 1 if successful, 0 otherwise * 1 if successful, 0 otherwise
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
static cchar *try_setoption(lua_State *L, COMPAT_FD sock) static const char *try_setoption(lua_State *L, p_sock ps)
{ {
static cchar *options[] = { static const char *options[] = {
"SO_KEEPALIVE", "SO_DONTROUTE", "SO_BROADCAST", "SO_LINGER", NULL "SO_KEEPALIVE", "SO_DONTROUTE", "SO_BROADCAST", NULL
}; };
cchar *option = lua_tostring(L, -2); const char *option = lua_tostring(L, -2);
if (!lua_isstring(L, -2)) return "invalid option"; if (!lua_isstring(L, -2)) return "invalid option";
switch (luaL_findstring(option, options)) { switch (luaL_findstring(option, options)) {
case 0: return try_setbooloption(L, sock, SO_KEEPALIVE); case 0: return try_setbooloption(L, ps, SO_KEEPALIVE);
case 1: return try_setbooloption(L, sock, SO_DONTROUTE); case 1: return try_setbooloption(L, ps, SO_DONTROUTE);
case 2: return try_setbooloption(L, sock, SO_BROADCAST); case 2: return try_setbooloption(L, ps, SO_BROADCAST);
case 3: return "SO_LINGER is deprecated";
default: return "unsupported option"; default: return "unsupported option";
} }
} }
/*=========================================================================*\
* Internal functions.
\*=========================================================================*/
static const char *try_setbooloption(lua_State *L, p_sock ps, int name)
{
int bool, res;
if (!lua_isnumber(L, -1)) luaL_error(L, "invalid option value");
bool = (int) lua_tonumber(L, -1);
res = setsockopt(*ps, SOL_SOCKET, name, (char *) &bool, sizeof(bool));
if (res < 0) return "error setting option";
else return NULL;
}

View File

@ -1,10 +1,10 @@
/*=========================================================================*\ /*=========================================================================*\
* Network compatibilization module: Unix version * Socket compatibilization module for Unix
* *
* RCS ID: $Id$ * RCS ID: $Id$
\*=========================================================================*/ \*=========================================================================*/
#ifndef UNIX_H_ #ifndef UNIX_H
#define UNIX_H_ #define UNIX_H
/*=========================================================================*\ /*=========================================================================*\
* BSD include files * BSD include files
@ -31,13 +31,9 @@
#include <netinet/in.h> #include <netinet/in.h>
#include <arpa/inet.h> #include <arpa/inet.h>
#define COMPAT_FD int typedef int t_sock;
#define COMPAT_INVALIDFD (-1) typedef t_sock *p_sock;
#define compat_bind bind #define SOCK_INVALID (-1)
#define compat_connect connect
#define compat_listen listen
#define compat_close close
#define compat_select select
#endif /* UNIX_H_ */ #endif /* UNIX_H */

View File

@ -1,5 +1,3 @@
dofile("noglobals.lua")
local similar = function(s1, s2) local similar = function(s1, s2)
return return
string.lower(string.gsub(s1, "%s", "")) == string.lower(string.gsub(s1, "%s", "")) ==
@ -34,7 +32,7 @@ end
local index, err, saved, back, expected local index, err, saved, back, expected
local t = socket._time() local t = socket.time()
index = readfile("test/index.html") index = readfile("test/index.html")
@ -112,4 +110,4 @@ back, err = socket.ftp.get("ftp://localhost/index.wrong.html;type=a")
check(err, err) check(err, err)
print("passed all tests") print("passed all tests")
print(string.format("done in %.2fs", socket._time() - t)) print(string.format("done in %.2fs", socket.time() - t))

View File

@ -3,9 +3,6 @@
-- needs ScriptAlias from /home/c/diego/tec/luasocket/test/cgi -- needs ScriptAlias from /home/c/diego/tec/luasocket/test/cgi
-- to /luasocket-test-cgi -- to /luasocket-test-cgi
-- needs AllowOverride AuthConfig on /home/c/diego/tec/luasocket/test/auth -- needs AllowOverride AuthConfig on /home/c/diego/tec/luasocket/test/auth
dofile("noglobals.lua")
local similar = function(s1, s2) local similar = function(s1, s2)
return string.lower(string.gsub(s1 or "", "%s", "")) == return string.lower(string.gsub(s1 or "", "%s", "")) ==
string.lower(string.gsub(s2 or "", "%s", "")) string.lower(string.gsub(s2 or "", "%s", ""))
@ -27,27 +24,27 @@ end
local check = function (v, e) local check = function (v, e)
if v then print("ok") if v then print("ok")
else %fail(e) end else fail(e) end
end end
local check_request = function(request, expect, ignore) local check_request = function(request, expect, ignore)
local response = socket.http.request(request) local response = socket.http.request(request)
for i,v in response do for i,v in response do
if not ignore[i] then if not ignore[i] then
if v ~= expect[i] then %fail(i .. " differs!") end if v ~= expect[i] then fail(i .. " differs!") end
end end
end end
for i,v in expect do for i,v in expect do
if not ignore[i] then if not ignore[i] then
if v ~= response[i] then %fail(i .. " differs!") end if v ~= response[i] then fail(i .. " differs!") end
end end
end end
print("ok") print("ok")
end end
local request, response, ignore, expect, index, prefix, cgiprefix local host, request, response, ignore, expect, index, prefix, cgiprefix
local t = socket._time() local t = socket.time()
host = host or "localhost" host = host or "localhost"
prefix = prefix or "/luasocket" prefix = prefix or "/luasocket"
@ -310,4 +307,4 @@ check(response and response.headers)
print("passed all tests") print("passed all tests")
print(string.format("done in %.2fs", socket._time() - t)) print(string.format("done in %.2fs", socket.time() - t))

View File

@ -11,7 +11,7 @@ local files = {
"/var/spool/mail/luasock3", "/var/spool/mail/luasock3",
} }
local t = socket._time() local t = socket.time()
local err local err
dofile("mbox.lua") dofile("mbox.lua")
@ -106,7 +106,7 @@ local insert = function(sent, message)
end end
local mark = function() local mark = function()
local time = socket._time() local time = socket.time()
return { time = time } return { time = time }
end end
@ -116,11 +116,11 @@ local wait = function(sentinel, n)
while 1 do while 1 do
local mbox = parse(get()) local mbox = parse(get())
if n == table.getn(mbox) then break end if n == table.getn(mbox) then break end
if socket._time() - sentinel.time > 50 then if socket.time() - sentinel.time > 50 then
to = 1 to = 1
break break
end end
socket._sleep(1) socket.sleep(1)
io.write(".") io.write(".")
io.stdout:flush() io.stdout:flush()
end end
@ -256,4 +256,4 @@ for i = 1, table.getn(mbox) do
end end
print("passed all tests") print("passed all tests")
print(string.format("done in %.2fs", socket._time() - t)) print(string.format("done in %.2fs", socket.time() - t))

View File

@ -43,7 +43,7 @@ function check_timeout(tm, sl, elapsed, err, opp, mode, alldone)
else pass("proper timeout") end else pass("proper timeout") end
end end
else else
if mode == "return" then if mode == "total" then
if elapsed > tm then if elapsed > tm then
if err ~= "timeout" then fail("should have timed out") if err ~= "timeout" then fail("should have timed out")
else pass("proper timeout") end else pass("proper timeout") end
@ -66,17 +66,17 @@ function check_timeout(tm, sl, elapsed, err, opp, mode, alldone)
end end
end end
if not socket.debug then
fail("Please define LUASOCKET_DEBUG and recompile LuaSocket")
end
io.write("----------------------------------------------\n", io.write("----------------------------------------------\n",
"LuaSocket Test Procedures\n", "LuaSocket Test Procedures\n",
"----------------------------------------------\n") "----------------------------------------------\n")
if not socket._time or not socket._sleep then start = socket.time()
fail("not compiled with _DEBUG")
end
start = socket._time() function reconnect()
function tcpreconnect()
io.write("attempting data connection... ") io.write("attempting data connection... ")
if data then data:close() end if data then data:close() end
remote [[ remote [[
@ -87,109 +87,85 @@ function tcpreconnect()
if not data then fail(err) if not data then fail(err)
else pass("connected!") end else pass("connected!") end
end end
reconnect = tcpreconnect
pass("attempting control connection...") pass("attempting control connection...")
control, err = socket.connect(host, port) control, err = socket.connect(host, port)
if err then fail(err) if err then fail(err)
else pass("connected!") end else pass("connected!") end
------------------------------------------------------------------------
test("bugs")
io.write("empty host connect: ")
function empty_connect()
if data then data:close() data = nil end
remote [[
if data then data:close() data = nil end
data = server:accept()
]]
data, err = socket.connect("", port)
if not data then
pass("ok")
data = socket.connect(host, port)
else fail("should not have connected!") end
end
empty_connect()
io.write("active close: ")
function active_close()
reconnect()
if socket._isclosed(data) then fail("should not be closed") end
data:close()
if not socket._isclosed(data) then fail("should be closed") end
data = nil
local udp = socket.udp()
if socket._isclosed(udp) then fail("should not be closed") end
udp:close()
if not socket._isclosed(udp) then fail("should be closed") end
pass("ok")
end
active_close()
------------------------------------------------------------------------ ------------------------------------------------------------------------
test("method registration") test("method registration")
function test_methods(sock, methods) function test_methods(sock, methods)
for _, v in methods do for _, v in methods do
if type(sock[v]) ~= "function" then if type(sock[v]) ~= "function" then
fail(type(sock) .. " method " .. v .. "not registered") fail(sock.class .. " method '" .. v .. "' not registered")
end end
end end
pass(type(sock) .. " methods are ok") pass(sock.class .. " methods are ok")
end end
test_methods(control, { test_methods(socket.tcp(), {
"close", "connect",
"timeout",
"send", "send",
"receive", "receive",
"bind",
"accept",
"setpeername",
"setsockname",
"getpeername", "getpeername",
"getsockname" "getsockname",
"timeout",
"close",
}) })
if udpsocket then test_methods(socket.udp(), {
test_methods(socket.udp(), { "getpeername",
"close", "getsockname",
"timeout", "setsockname",
"send", "setpeername",
"sendto", "send",
"receive", "sendto",
"receivefrom", "receive",
"getpeername", "receivefrom",
"getsockname",
"setsockname",
"setpeername"
})
end
test_methods(socket.bind("*", 0), {
"close",
"timeout", "timeout",
"accept" "close",
}) })
------------------------------------------------------------------------ ------------------------------------------------------------------------
test("select function") test("mixed patterns")
function test_selectbugs()
local r, s, e = socket.select(nil, nil, 0.1) function test_mixed(len)
assert(type(r) == "table" and type(s) == "table" and e == "timeout") reconnect()
pass("both nil: ok") local inter = math.ceil(len/4)
local udp = socket.udp() local p1 = "unix " .. string.rep("x", inter) .. "line\n"
udp:close() local p2 = "dos " .. string.rep("y", inter) .. "line\r\n"
r, s, e = socket.select({ udp }, { udp }, 0.1) local p3 = "raw " .. string.rep("z", inter) .. "bytes"
assert(type(r) == "table" and type(s) == "table" and e == "timeout") local p4 = "end" .. string.rep("w", inter) .. "bytes"
pass("closed sockets: ok") local bp1, bp2, bp3, bp4
e = pcall(socket.select, "wrong", 1, 0.1) pass(len .. " byte(s) patterns")
assert(e == false) remote (string.format("str = data:receive(%d)",
e = pcall(socket.select, {}, 1, 0.1) string.len(p1)+string.len(p2)+string.len(p3)+string.len(p4)))
assert(e == false) sent, err = data:send(p1, p2, p3, p4)
pass("invalid input: ok") if err then fail(err) end
remote "data:send(str); data:close()"
bp1, bp2, bp3, bp4, err = data:receive("*l", "*l", string.len(p3), "*a")
if err then fail(err) end
if bp1.."\n" == p1 and bp2.."\r\n" == p2 and bp3 == p3 and bp4 == p4 then
pass("patterns match")
else fail("patterns don't match") end
end end
test_selectbugs()
test_mixed(1)
test_mixed(17)
test_mixed(200)
test_mixed(4091)
test_mixed(80199)
test_mixed(4091)
test_mixed(200)
test_mixed(17)
test_mixed(1)
------------------------------------------------------------------------ ------------------------------------------------------------------------
test("character line") test("character line")
@ -202,7 +178,7 @@ function test_asciiline(len)
str = str .. str10 str = str .. str10
pass(len .. " byte(s) line") pass(len .. " byte(s) line")
remote "str = data:receive()" remote "str = data:receive()"
err = data:send(str, "\n") sent, err = data:send(str, "\n")
if err then fail(err) end if err then fail(err) end
remote "data:send(str, '\\n')" remote "data:send(str, '\\n')"
back, err = data:receive() back, err = data:receive()
@ -230,7 +206,7 @@ function test_rawline(len)
str = str .. str10 str = str .. str10
pass(len .. " byte(s) line") pass(len .. " byte(s) line")
remote "str = data:receive()" remote "str = data:receive()"
err = data:send(str, "\n") sent, err = data:send(str, "\n")
if err then fail(err) end if err then fail(err) end
remote "data:send(str, '\\n')" remote "data:send(str, '\\n')"
back, err = data:receive() back, err = data:receive()
@ -262,9 +238,9 @@ function test_raw(len)
s2 = string.rep("y", len-half) s2 = string.rep("y", len-half)
pass(len .. " byte(s) block") pass(len .. " byte(s) block")
remote (string.format("str = data:receive(%d)", len)) remote (string.format("str = data:receive(%d)", len))
err = data:send(s1) sent, err = data:send(s1)
if err then fail(err) end if err then fail(err) end
err = data:send(s2) sent, err = data:send(s2)
if err then fail(err) end if err then fail(err) end
remote "data:send(str)" remote "data:send(str)"
back, err = data:receive(len) back, err = data:receive(len)
@ -304,39 +280,139 @@ test_raw(17)
test_raw(1) test_raw(1)
------------------------------------------------------------------------ ------------------------------------------------------------------------
test("mixed patterns") test("total timeout on receive")
reconnect() function test_totaltimeoutreceive(len, tm, sl)
local str, err, total
reconnect()
pass("%d bytes, %ds total timeout, %ds pause", len, tm, sl)
remote (string.format ([[
data:timeout(%d)
str = string.rep('a', %d)
data:send(str)
print('server: sleeping for %ds')
socket.sleep(%d)
print('server: woke up')
data:send(str)
]], 2*tm, len, sl, sl))
data:timeout(tm, "total")
str, err, elapsed = data:receive(2*len)
check_timeout(tm, sl, elapsed, err, "receive", "total",
string.len(str) == 2*len)
end
test_totaltimeoutreceive(800091, 1, 3)
test_totaltimeoutreceive(800091, 2, 3)
test_totaltimeoutreceive(800091, 3, 2)
test_totaltimeoutreceive(800091, 3, 1)
function test_mixed(len) ------------------------------------------------------------------------
local inter = math.floor(len/3) test("total timeout on send")
local p1 = "unix " .. string.rep("x", inter) .. "line\n" function test_totaltimeoutsend(len, tm, sl)
local p2 = "dos " .. string.rep("y", inter) .. "line\r\n" local str, err, total
local p3 = "raw " .. string.rep("z", inter) .. "bytes" reconnect()
local bp1, bp2, bp3 pass("%d bytes, %ds total timeout, %ds pause", len, tm, sl)
pass(len .. " byte(s) patterns") remote (string.format ([[
remote (string.format("str = data:receive(%d)", data:timeout(%d)
string.len(p1)+string.len(p2)+string.len(p3))) str = data:receive(%d)
err = data:send(p1, p2, p3) print('server: sleeping for %ds')
if err then fail(err) end socket.sleep(%d)
remote "data:send(str)" print('server: woke up')
bp1, bp2, bp3, err = data:receive("*lu", "*l", string.len(p3)) str = data:receive(%d)
if err then fail(err) end ]], 2*tm, len, sl, sl, len))
if bp1.."\n" == p1 and bp2.."\r\n" == p2 and bp3 == p3 then data:timeout(tm, "total")
pass("patterns match") str = string.rep("a", 2*len)
else fail("patterns don't match") end total, err, elapsed = data:send(str)
check_timeout(tm, sl, elapsed, err, "send", "total",
total == 2*len)
end
test_totaltimeoutsend(800091, 1, 3)
test_totaltimeoutsend(800091, 2, 3)
test_totaltimeoutsend(800091, 3, 2)
test_totaltimeoutsend(800091, 3, 1)
------------------------------------------------------------------------
test("blocking timeout on receive")
function test_blockingtimeoutreceive(len, tm, sl)
local str, err, total
reconnect()
pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl)
remote (string.format ([[
data:timeout(%d)
str = string.rep('a', %d)
data:send(str)
print('server: sleeping for %ds')
socket.sleep(%d)
print('server: woke up')
data:send(str)
]], 2*tm, len, sl, sl))
data:timeout(tm)
str, err, elapsed = data:receive(2*len)
check_timeout(tm, sl, elapsed, err, "receive", "blocking",
string.len(str) == 2*len)
end
test_blockingtimeoutreceive(800091, 1, 3)
test_blockingtimeoutreceive(800091, 2, 3)
test_blockingtimeoutreceive(800091, 3, 2)
test_blockingtimeoutreceive(800091, 3, 1)
------------------------------------------------------------------------
test("blocking timeout on send")
function test_blockingtimeoutsend(len, tm, sl)
local str, err, total
reconnect()
pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl)
remote (string.format ([[
data:timeout(%d)
str = data:receive(%d)
print('server: sleeping for %ds')
socket.sleep(%d)
print('server: woke up')
str = data:receive(%d)
]], 2*tm, len, sl, sl, len))
data:timeout(tm)
str = string.rep("a", 2*len)
total, err, elapsed = data:send(str)
check_timeout(tm, sl, elapsed, err, "send", "blocking",
total == 2*len)
end
test_blockingtimeoutsend(800091, 1, 3)
test_blockingtimeoutsend(800091, 2, 3)
test_blockingtimeoutsend(800091, 3, 2)
test_blockingtimeoutsend(800091, 3, 1)
------------------------------------------------------------------------
test("bugs")
io.write("empty host connect: ")
function empty_connect()
if data then data:close() data = nil end
remote [[
if data then data:close() data = nil end
data = server:accept()
]]
data, err = socket.connect("", port)
if not data then
pass("ok")
data = socket.connect(host, port)
else fail("should not have connected!") end
end end
test_mixed(1) empty_connect()
test_mixed(17)
test_mixed(200) -- io.write("active close: ")
test_mixed(4091) function active_close()
test_mixed(80199) reconnect()
test_mixed(800000) if socket._isclosed(data) then fail("should not be closed") end
test_mixed(80199) data:close()
test_mixed(4091) if not socket._isclosed(data) then fail("should be closed") end
test_mixed(200) data = nil
test_mixed(17) local udp = socket.udp()
test_mixed(1) if socket._isclosed(udp) then fail("should not be closed") end
udp:close()
if not socket._isclosed(udp) then fail("should be closed") end
pass("ok")
end
-- active_close()
------------------------------------------------------------------------ ------------------------------------------------------------------------
test("closed connection detection") test("closed connection detection")
@ -363,7 +439,7 @@ function test_closed()
data:close() data:close()
data = nil data = nil
]] ]]
err, total = data:send(string.rep("ugauga", 100000)) total, err = data:send(string.rep("ugauga", 100000))
if not err then if not err then
pass("failed: output buffer is at least %d bytes long!", total) pass("failed: output buffer is at least %d bytes long!", total)
elseif err ~= "closed" then elseif err ~= "closed" then
@ -376,106 +452,26 @@ end
test_closed() test_closed()
------------------------------------------------------------------------ ------------------------------------------------------------------------
test("return timeout on receive") test("select function")
function test_blockingtimeoutreceive(len, tm, sl) function test_selectbugs()
local str, err, total local r, s, e = socket.select(nil, nil, 0.1)
reconnect() assert(type(r) == "table" and type(s) == "table" and e == "timeout")
pass("%d bytes, %ds return timeout, %ds pause", len, tm, sl) pass("both nil: ok")
remote (string.format ([[ local udp = socket.udp()
data:timeout(%d) udp:close()
str = string.rep('a', %d) r, s, e = socket.select({ udp }, { udp }, 0.1)
data:send(str) assert(type(r) == "table" and type(s) == "table" and e == "timeout")
print('server: sleeping for %ds') pass("closed sockets: ok")
socket._sleep(%d) e = pcall(socket.select, "wrong", 1, 0.1)
print('server: woke up') assert(e == false)
data:send(str) e = pcall(socket.select, {}, 1, 0.1)
]], 2*tm, len, sl, sl)) assert(e == false)
data:timeout(tm, "return") pass("invalid input: ok")
str, err, elapsed = data:receive(2*len)
check_timeout(tm, sl, elapsed, err, "receive", "return",
string.len(str) == 2*len)
end end
test_blockingtimeoutreceive(800091, 1, 3)
test_blockingtimeoutreceive(800091, 2, 3)
test_blockingtimeoutreceive(800091, 3, 2)
test_blockingtimeoutreceive(800091, 3, 1)
------------------------------------------------------------------------ -- test_selectbugs()
test("return timeout on send")
function test_returntimeoutsend(len, tm, sl)
local str, err, total
reconnect()
pass("%d bytes, %ds return timeout, %ds pause", len, tm, sl)
remote (string.format ([[
data:timeout(%d)
str = data:receive(%d)
print('server: sleeping for %ds')
socket._sleep(%d)
print('server: woke up')
str = data:receive(%d)
]], 2*tm, len, sl, sl, len))
data:timeout(tm, "return")
str = string.rep("a", 2*len)
err, total, elapsed = data:send(str)
check_timeout(tm, sl, elapsed, err, "send", "return",
total == 2*len)
end
test_returntimeoutsend(800091, 1, 3)
test_returntimeoutsend(800091, 2, 3)
test_returntimeoutsend(800091, 3, 2)
test_returntimeoutsend(800091, 3, 1)
------------------------------------------------------------------------
test("blocking timeout on receive")
function test_blockingtimeoutreceive(len, tm, sl)
local str, err, total
reconnect()
pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl)
remote (string.format ([[
data:timeout(%d)
str = string.rep('a', %d)
data:send(str)
print('server: sleeping for %ds')
socket._sleep(%d)
print('server: woke up')
data:send(str)
]], 2*tm, len, sl, sl))
data:timeout(tm)
str, err, elapsed = data:receive(2*len)
check_timeout(tm, sl, elapsed, err, "receive", "blocking",
string.len(str) == 2*len)
end
test_blockingtimeoutreceive(800091, 1, 3)
test_blockingtimeoutreceive(800091, 2, 3)
test_blockingtimeoutreceive(800091, 3, 2)
test_blockingtimeoutreceive(800091, 3, 1)
------------------------------------------------------------------------ test(string.format("done in %.2fs", socket.time() - start))
test("blocking timeout on send")
function test_blockingtimeoutsend(len, tm, sl)
local str, err, total
reconnect()
pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl)
remote (string.format ([[
data:timeout(%d)
str = data:receive(%d)
print('server: sleeping for %ds')
socket._sleep(%d)
print('server: woke up')
str = data:receive(%d)
]], 2*tm, len, sl, sl, len))
data:timeout(tm)
str = string.rep("a", 2*len)
err, total, elapsed = data:send(str)
check_timeout(tm, sl, elapsed, err, "send", "blocking",
total == 2*len)
end
test_blockingtimeoutsend(800091, 1, 3)
test_blockingtimeoutsend(800091, 2, 3)
test_blockingtimeoutsend(800091, 3, 2)
test_blockingtimeoutsend(800091, 3, 1)
------------------------------------------------------------------------
test(string.format("done in %.2fs", socket._time() - start))

View File

@ -13,12 +13,13 @@ while 1 do
print("server: closing connection...") print("server: closing connection...")
break break
end end
error = control:send("\n") sent, error = control:send("\n")
if error then if error then
control:close() control:close()
print("server: closing connection...") print("server: closing connection...")
break break
end end
print(command);
(loadstring(command))() (loadstring(command))()
end end
end end

View File

@ -1,5 +1,5 @@
-- load tftpclnt.lua -- load tftpclnt.lua
dofile("tftpclnt.lua") dofile("tftp.lua")
-- needs tftp server running on localhost, with root pointing to -- needs tftp server running on localhost, with root pointing to
-- a directory with index.html in it -- a directory with index.html in it
@ -13,11 +13,8 @@ function readfile(file)
end end
host = host or "localhost" host = host or "localhost"
print("downloading") retrieved, err = socket.tftp.get("tftp://" .. host .."/index.html")
err = tftp_get(host, 69, "index.html", "index.got")
assert(not err, err) assert(not err, err)
original = readfile("test/index.html") original = readfile("test/index.html")
retrieved = readfile("index.got")
os.remove("index.got")
assert(original == retrieved, "files differ!") assert(original == retrieved, "files differ!")
print("passed") print("passed")