Porting to LUA 5.0 final

This commit is contained in:
Diego Nehab 2003-05-25 01:54:13 +00:00
parent c1ef3e7103
commit 0f6c8d50a9
32 changed files with 1539 additions and 1128 deletions

25
NEW
View File

@ -1,5 +1,20 @@
Socket structures are independent
UDPBUFFERSIZE is now internal
Better treatment of closed connections: test!!!
HTTP post now deals with 1xx codes
connect, bind etc only try first address returned by resolver
All functions provided by the library are in the namespace "socket".
Functions such as send/receive/timeout/close etc do not exist in the
namespace. They are now only available as methods of the appropriate
objects.
Object has been changed to become more uniform. First create an object for
a given domain/family and protocol. Then connect or bind if needed. Then
use IO functions.
All functions return a non-nil value as first return value if successful.
All functions return nil followed by error message in case of error.
WARNING: The send function was affected.
Better error messages and parameter checking.
UDP connected udp sockets can break association with peer by calling
setpeername with address "*".
socket.sleep and socket.time are now part of the library and are
supported.

131
etc/tftp.lua Normal file
View File

@ -0,0 +1,131 @@
-----------------------------------------------------------------------------
-- TFTP support for the Lua language
-- LuaSocket 1.5 toolkit.
-- Author: Diego Nehab
-- Conforming to: RFC 783, LTN7
-- RCS ID: $Id$
-----------------------------------------------------------------------------
local Public, Private = {}, {}
local socket = _G[LUASOCKET_LIBNAME] -- get LuaSocket namespace
socket.tftp = Public -- create tftp sub namespace
-----------------------------------------------------------------------------
-- Program constants
-----------------------------------------------------------------------------
local char = string.char
local byte = string.byte
Public.PORT = 69
Private.OP_RRQ = 1
Private.OP_WRQ = 2
Private.OP_DATA = 3
Private.OP_ACK = 4
Private.OP_ERROR = 5
Private.OP_INV = {"RRQ", "WRQ", "DATA", "ACK", "ERROR"}
-----------------------------------------------------------------------------
-- Packet creation functions
-----------------------------------------------------------------------------
function Private.RRQ(source, mode)
return char(0, Private.OP_RRQ) .. source .. char(0) .. mode .. char(0)
end
function Private.WRQ(source, mode)
return char(0, Private.OP_RRQ) .. source .. char(0) .. mode .. char(0)
end
function Private.ACK(block)
local low, high
low = math.mod(block, 256)
high = (block - low)/256
return char(0, Private.OP_ACK, high, low)
end
function Private.get_OP(dgram)
local op = byte(dgram, 1)*256 + byte(dgram, 2)
return op
end
-----------------------------------------------------------------------------
-- Packet analysis functions
-----------------------------------------------------------------------------
function Private.split_DATA(dgram)
local block = byte(dgram, 3)*256 + byte(dgram, 4)
local data = string.sub(dgram, 5)
return block, data
end
function Private.get_ERROR(dgram)
local code = byte(dgram, 3)*256 + byte(dgram, 4)
local msg
_,_, msg = string.find(dgram, "(.*)\000", 5)
return string.format("error code %d: %s", code, msg)
end
-----------------------------------------------------------------------------
-- Downloads and returns a file pointed to by url
-----------------------------------------------------------------------------
function Public.get(url)
local parsed = socket.url.parse(url, {
host = "",
port = Public.PORT,
path ="/",
scheme = "tftp"
})
if parsed.scheme ~= "tftp" then
return nil, string.format("unknown scheme '%s'", parsed.scheme)
end
local retries, dgram, sent, datahost, dataport, code
local cat = socket.concat.create()
local last = 0
local udp, err = socket.udp()
if not udp then return nil, err end
-- convert from name to ip if needed
parsed.host = socket.toip(parsed.host)
udp:timeout(1)
-- first packet gives data host/port to be used for data transfers
retries = 0
repeat
sent, err = udp:sendto(Private.RRQ(parsed.path, "octet"),
parsed.host, parsed.port)
if err then return nil, err end
dgram, datahost, dataport = udp:receivefrom()
retries = retries + 1
until dgram or datahost ~= "timeout" or retries > 5
if not dgram then return nil, datahost end
-- associate socket with data host/port
udp:setpeername(datahost, dataport)
-- process all data packets
while 1 do
-- decode packet
code = Private.get_OP(dgram)
if code == Private.OP_ERROR then
return nil, Private.get_ERROR(dgram)
end
if code ~= Private.OP_DATA then
return nil, "unhandled opcode " .. code
end
-- get data packet parts
local block, data = Private.split_DATA(dgram)
-- if not repeated, write
if block == last+1 then
cat:addstring(data)
last = block
end
-- last packet brings less than 512 bytes of data
if string.len(data) < 512 then
sent, err = udp:send(Private.ACK(block))
return cat:getresult()
end
-- get the next packet
retries = 0
repeat
sent, err = udp:send(Private.ACK(last))
if err then return err end
dgram, err = udp:receive()
retries = retries + 1
until dgram or err ~= "timeout" or retries > 5
if not dgram then return err end
end
end

View File

@ -7,7 +7,8 @@ end
host = socket.toip(host)
udp = socket.udp()
print("Using host '" ..host.. "' and port " ..port.. "...")
err = udp:sendto("anything", host, port)
udp:setpeername(host, port)
sent, err = udp:send("anything")
if err then print(err) exit() end
dgram, err = udp:receive()
if not dgram then print(err) exit() end

View File

@ -5,18 +5,18 @@ if arg then
port = arg[2] or port
end
print("Attempting connection to host '" ..host.. "' and port " ..port.. "...")
c, e = connect(host, port)
c, e = socket.connect(host, port)
if not c then
print(e)
exit()
os.exit()
end
print("Connected! Please type stuff (empty line to stop):")
l = read()
l = io.read()
while l and l ~= "" and not e do
e = c:send(l, "\n")
t, e = c:send(l, "\n")
if e then
print(e)
exit()
os.exit()
end
l = read()
l = io.read()
end

113
src/auxiliar.c Normal file
View File

@ -0,0 +1,113 @@
/*=========================================================================*\
* Auxiliar routines for class hierarchy manipulation
*
* RCS ID: $Id$
\*=========================================================================*/
#include "aux.h"
/*=========================================================================*\
* Internal function prototypes
\*=========================================================================*/
static void *aux_getgroupudata(lua_State *L, const char *group, int objidx);
/*=========================================================================*\
* Exported functions
\*=========================================================================*/
/*-------------------------------------------------------------------------*\
* Creates a new class. A class has methods given by the func array and the
* field 'class' tells the object class. The table 'group' list the class
* groups the object belongs to.
\*-------------------------------------------------------------------------*/
void aux_newclass(lua_State *L, const char *name, luaL_reg *func)
{
luaL_newmetatable(L, name);
lua_pushstring(L, "__index");
lua_newtable(L);
luaL_openlib(L, NULL, func, 0);
lua_pushstring(L, "class");
lua_pushstring(L, name);
lua_settable(L, -3);
lua_settable(L, -3);
lua_pushstring(L, "group");
lua_newtable(L);
lua_settable(L, -3);
lua_pop(L, 1);
}
/*-------------------------------------------------------------------------*\
* Add group to object list of groups.
\*-------------------------------------------------------------------------*/
void aux_add2group(lua_State *L, const char *name, const char *group)
{
luaL_getmetatable(L, name);
lua_pushstring(L, "group");
lua_gettable(L, -2);
lua_pushstring(L, group);
lua_pushnumber(L, 1);
lua_settable(L, -3);
lua_pop(L, 2);
}
/*-------------------------------------------------------------------------*\
* Get a userdata making sure the object belongs to a given class.
\*-------------------------------------------------------------------------*/
void *aux_checkclass(lua_State *L, const char *name, int objidx)
{
void *data = luaL_checkudata(L, objidx, name);
if (!data) {
char msg[45];
sprintf(msg, "%.35s expected", name);
luaL_argerror(L, objidx, msg);
}
return data;
}
/*-------------------------------------------------------------------------*\
* Get a userdata making sure the object belongs to a given group.
\*-------------------------------------------------------------------------*/
void *aux_checkgroup(lua_State *L, const char *group, int objidx)
{
void *data = aux_getgroupudata(L, group, objidx);
if (!data) {
char msg[45];
sprintf(msg, "%.35s expected", group);
luaL_argerror(L, objidx, msg);
}
return data;
}
/*-------------------------------------------------------------------------*\
* Set object class.
\*-------------------------------------------------------------------------*/
void aux_setclass(lua_State *L, const char *name, int objidx)
{
luaL_getmetatable(L, name);
if (objidx < 0) objidx--;
lua_setmetatable(L, objidx);
}
/*=========================================================================*\
* Internal functions
\*=========================================================================*/
/*-------------------------------------------------------------------------*\
* Get a userdata if object belongs to a given group.
\*-------------------------------------------------------------------------*/
static void *aux_getgroupudata(lua_State *L, const char *group, int objidx)
{
if (!lua_getmetatable(L, objidx)) return NULL;
lua_pushstring(L, "group");
lua_gettable(L, -2);
if (lua_isnil(L, -1)) {
lua_pop(L, 2);
return NULL;
}
lua_pushstring(L, group);
lua_gettable(L, -2);
if (lua_isnil(L, -1)) {
lua_pop(L, 3);
return NULL;
}
lua_pop(L, 3);
return lua_touserdata(L, objidx);
}

26
src/auxiliar.h Normal file
View File

@ -0,0 +1,26 @@
/*=========================================================================*\
* Auxiliar routines for class hierarchy manipulation
*
* RCS ID: $Id$
\*=========================================================================*/
#ifndef AUX_H
#define AUX_H
#include <lua.h>
#include <lauxlib.h>
void aux_newclass(lua_State *L, const char *name, luaL_reg *func);
void aux_add2group(lua_State *L, const char *name, const char *group);
void *aux_checkclass(lua_State *L, const char *name, int objidx);
void *aux_checkgroup(lua_State *L, const char *group, int objidx);
void aux_setclass(lua_State *L, const char *name, int objidx);
/* min and max macros */
#ifndef MIN
#define MIN(x, y) ((x) < (y) ? x : y)
#endif
#ifndef MAX
#define MAX(x, y) ((x) > (y) ? x : y)
#endif
#endif

View File

@ -1,28 +1,24 @@
/*=========================================================================*\
* Buffered input/output routines
* Lua methods:
* send: unbuffered send using C base_send
* receive: buffered read using C base_receive
*
* RCS ID: $Id$
\*=========================================================================*/
#include <lua.h>
#include <lauxlib.h>
#include "lsbuf.h"
#include "error.h"
#include "aux.h"
#include "buf.h"
/*=========================================================================*\
* Internal function prototypes.
* Internal function prototypes
\*=========================================================================*/
static int sendraw(lua_State *L, p_buf buf, cchar *data, size_t len,
size_t *done);
static int recvraw(lua_State *L, p_buf buf, size_t wanted);
static int recvdosline(lua_State *L, p_buf buf);
static int recvunixline(lua_State *L, p_buf buf);
static int recvline(lua_State *L, p_buf buf);
static int recvall(lua_State *L, p_buf buf);
static int buf_contents(lua_State *L, p_buf buf, cchar **data, size_t *len);
static void buf_skip(lua_State *L, p_buf buf, size_t len);
static int buf_get(p_buf buf, const char **data, size_t *count);
static void buf_skip(p_buf buf, size_t count);
static int sendraw(p_buf buf, const char *data, size_t count, size_t *sent);
/*=========================================================================*\
* Exported functions
@ -37,98 +33,69 @@ void buf_open(lua_State *L)
/*-------------------------------------------------------------------------*\
* Initializes C structure
* Input
* buf: buffer structure to initialize
* base: socket object to associate with buffer structure
\*-------------------------------------------------------------------------*/
void buf_init(lua_State *L, p_buf buf, p_base base)
void buf_init(p_buf buf, p_io io, p_tm tm)
{
(void) L;
buf->buf_first = buf->buf_last = 0;
buf->buf_base = base;
buf->first = buf->last = 0;
buf->io = io;
buf->tm = tm;
}
/*-------------------------------------------------------------------------*\
* Send data through buffered object
* Input
* buf: buffer structure to be used
* Lua Input: self, a_1 [, a_2, a_3 ... a_n]
* self: socket object
* a_i: strings to be sent.
* Lua Returns
* On success: nil, followed by the total number of bytes sent
* On error: error message
\*-------------------------------------------------------------------------*/
int buf_send(lua_State *L, p_buf buf)
int buf_meth_send(lua_State *L, p_buf buf)
{
int top = lua_gettop(L);
size_t total = 0;
int err = PRIV_DONE;
int arg;
p_base base = buf->buf_base;
tm_markstart(&base->base_tm);
int arg, err = IO_DONE;
p_tm tm = buf->tm;
tm_markstart(tm);
for (arg = 2; arg <= top; arg++) { /* first arg is socket object */
size_t done, len;
cchar *data = luaL_optlstring(L, arg, NULL, &len);
if (!data || err != PRIV_DONE) break;
err = sendraw(L, buf, data, len, &done);
total += done;
size_t sent, count;
const char *data = luaL_optlstring(L, arg, NULL, &count);
if (!data || err != IO_DONE) break;
err = sendraw(buf, data, count, &sent);
total += sent;
}
priv_pusherror(L, err);
lua_pushnumber(L, total);
error_push(L, err);
#ifdef LUASOCKET_DEBUG
/* push time elapsed during operation as the last return value */
lua_pushnumber(L, tm_getelapsed(&base->base_tm)/1000.0);
lua_pushnumber(L, (tm_gettime() - tm_getstart(tm))/1000.0);
#endif
return lua_gettop(L) - top;
}
/*-------------------------------------------------------------------------*\
* Receive data from a buffered object
* Input
* buf: buffer structure to be used
* Lua Input: self [pat_1, pat_2 ... pat_n]
* self: socket object
* pat_i: may be one of the following
* "*l": reads a text line, defined as a string of caracters terminates
* by a LF character, preceded or not by a CR character. This is
* the default pattern
* "*lu": reads a text line, terminanted by a CR character only. (Unix mode)
* "*a": reads until connection closed
* number: reads 'number' characters from the socket object
* Lua Returns
* On success: one string for each pattern
* On error: all strings for which there was no error, followed by one
* nil value for the remaining strings, followed by an error code
\*-------------------------------------------------------------------------*/
int buf_receive(lua_State *L, p_buf buf)
int buf_meth_receive(lua_State *L, p_buf buf)
{
int top = lua_gettop(L);
int arg, err = PRIV_DONE;
p_base base = buf->buf_base;
tm_markstart(&base->base_tm);
int arg, err = IO_DONE;
p_tm tm = buf->tm;
tm_markstart(tm);
/* push default pattern if need be */
if (top < 2) {
lua_pushstring(L, "*l");
top++;
}
/* make sure we have enough stack space */
/* make sure we have enough stack space for all returns */
luaL_checkstack(L, top+LUA_MINSTACK, "too many arguments");
/* receive all patterns */
for (arg = 2; arg <= top && err == PRIV_DONE; arg++) {
for (arg = 2; arg <= top && err == IO_DONE; arg++) {
if (!lua_isnumber(L, arg)) {
static cchar *patternnames[] = {"*l", "*lu", "*a", "*w", NULL};
cchar *pattern = luaL_optstring(L, arg, NULL);
static const char *patternnames[] = {"*l", "*a", NULL};
const char *pattern = lua_isnil(L, arg) ?
"*l" : luaL_checkstring(L, arg);
/* get next pattern */
switch (luaL_findstring(pattern, patternnames)) {
case 0: /* DOS line pattern */
err = recvdosline(L, buf); break;
case 1: /* Unix line pattern */
err = recvunixline(L, buf); break;
case 2: /* Until closed pattern */
err = recvall(L, buf); break;
case 3: /* Word pattern */
luaL_argcheck(L, 0, arg, "word patterns are deprecated");
case 0: /* line pattern */
err = recvline(L, buf); break;
case 1: /* until closed pattern */
err = recvall(L, buf);
if (err == IO_CLOSED) err = IO_DONE;
break;
default: /* else it is an error */
luaL_argcheck(L, 0, arg, "invalid receive pattern");
@ -140,25 +107,20 @@ int buf_receive(lua_State *L, p_buf buf)
/* push nil for each pattern after an error */
for ( ; arg <= top; arg++) lua_pushnil(L);
/* last return is an error code */
priv_pusherror(L, err);
error_push(L, err);
#ifdef LUASOCKET_DEBUG
/* push time elapsed during operation as the last return value */
lua_pushnumber(L, tm_getelapsed(&base->base_tm)/1000.0);
lua_pushnumber(L, (tm_gettime() - tm_getstart(tm))/1000.0);
#endif
return lua_gettop(L) - top;
}
/*-------------------------------------------------------------------------*\
* Determines if there is any data in the read buffer
* Input
* buf: buffer structure to be used
* Returns
* 1 if empty, 0 if there is data
\*-------------------------------------------------------------------------*/
int buf_isempty(lua_State *L, p_buf buf)
int buf_isempty(p_buf buf)
{
(void) L;
return buf->buf_first >= buf->buf_last;
return buf->first >= buf->last;
}
/*=========================================================================*\
@ -166,24 +128,16 @@ int buf_isempty(lua_State *L, p_buf buf)
\*=========================================================================*/
/*-------------------------------------------------------------------------*\
* Sends a raw block of data through a buffered object.
* Input
* buf: buffer structure to be used
* data: data to be sent
* len: number of bytes to send
* Output
* sent: number of bytes sent
* Returns
* operation error code.
\*-------------------------------------------------------------------------*/
static int sendraw(lua_State *L, p_buf buf, cchar *data, size_t len,
size_t *sent)
static int sendraw(p_buf buf, const char *data, size_t count, size_t *sent)
{
p_base base = buf->buf_base;
p_io io = buf->io;
p_tm tm = buf->tm;
size_t total = 0;
int err = PRIV_DONE;
while (total < len && err == PRIV_DONE) {
int err = IO_DONE;
while (total < count && err == IO_DONE) {
size_t done;
err = base->base_send(L, base, data + total, len - total, &done);
err = io->send(io->ctx, data+total, count-total, &done, tm_get(tm));
total += done;
}
*sent = total;
@ -192,25 +146,21 @@ static int sendraw(lua_State *L, p_buf buf, cchar *data, size_t len,
/*-------------------------------------------------------------------------*\
* Reads a raw block of data from a buffered object.
* Input
* buf: buffer structure
* wanted: number of bytes to be read
* Returns
* operation error code.
\*-------------------------------------------------------------------------*/
static int recvraw(lua_State *L, p_buf buf, size_t wanted)
static
int recvraw(lua_State *L, p_buf buf, size_t wanted)
{
int err = PRIV_DONE;
int err = IO_DONE;
size_t total = 0;
luaL_Buffer b;
luaL_buffinit(L, &b);
while (total < wanted && err == PRIV_DONE) {
size_t len; cchar *data;
err = buf_contents(L, buf, &data, &len);
len = MIN(len, wanted - total);
luaL_addlstring(&b, data, len);
buf_skip(L, buf, len);
total += len;
while (total < wanted && err == IO_DONE) {
size_t count; const char *data;
err = buf_get(buf, &data, &count);
count = MIN(count, wanted - total);
luaL_addlstring(&b, data, count);
buf_skip(buf, count);
total += count;
}
luaL_pushresult(&b);
return err;
@ -218,21 +168,18 @@ static int recvraw(lua_State *L, p_buf buf, size_t wanted)
/*-------------------------------------------------------------------------*\
* Reads everything until the connection is closed
* Input
* buf: buffer structure
* Result
* operation error code.
\*-------------------------------------------------------------------------*/
static int recvall(lua_State *L, p_buf buf)
static
int recvall(lua_State *L, p_buf buf)
{
int err = PRIV_DONE;
int err = IO_DONE;
luaL_Buffer b;
luaL_buffinit(L, &b);
while (err == PRIV_DONE) {
cchar *data; size_t len;
err = buf_contents(L, buf, &data, &len);
luaL_addlstring(&b, data, len);
buf_skip(L, buf, len);
while (err == IO_DONE) {
const char *data; size_t count;
err = buf_get(buf, &data, &count);
luaL_addlstring(&b, data, count);
buf_skip(buf, count);
}
luaL_pushresult(&b);
return err;
@ -241,61 +188,27 @@ static int recvall(lua_State *L, p_buf buf)
/*-------------------------------------------------------------------------*\
* Reads a line terminated by a CR LF pair or just by a LF. The CR and LF
* are not returned by the function and are discarded from the buffer.
* Input
* buf: buffer structure
* Result
* operation error code. PRIV_DONE, PRIV_TIMEOUT or PRIV_CLOSED
\*-------------------------------------------------------------------------*/
static int recvdosline(lua_State *L, p_buf buf)
static
int recvline(lua_State *L, p_buf buf)
{
int err = 0;
luaL_Buffer b;
luaL_buffinit(L, &b);
while (err == PRIV_DONE) {
size_t len, pos; cchar *data;
err = buf_contents(L, buf, &data, &len);
while (err == IO_DONE) {
size_t count, pos; const char *data;
err = buf_get(buf, &data, &count);
pos = 0;
while (pos < len && data[pos] != '\n') {
while (pos < count && data[pos] != '\n') {
/* we ignore all \r's */
if (data[pos] != '\r') luaL_putchar(&b, data[pos]);
pos++;
}
if (pos < len) { /* found '\n' */
buf_skip(L, buf, pos+1); /* skip '\n' too */
if (pos < count) { /* found '\n' */
buf_skip(buf, pos+1); /* skip '\n' too */
break; /* we are done */
} else /* reached the end of the buffer */
buf_skip(L, buf, pos);
}
luaL_pushresult(&b);
return err;
}
/*-------------------------------------------------------------------------*\
* Reads a line terminated by a LF character, which is not returned by
* the function, and is skipped in the buffer.
* Input
* buf: buffer structure
* Returns
* operation error code. PRIV_DONE, PRIV_TIMEOUT or PRIV_CLOSED
\*-------------------------------------------------------------------------*/
static int recvunixline(lua_State *L, p_buf buf)
{
int err = PRIV_DONE;
luaL_Buffer b;
luaL_buffinit(L, &b);
while (err == 0) {
size_t pos, len; cchar *data;
err = buf_contents(L, buf, &data, &len);
pos = 0;
while (pos < len && data[pos] != '\n') {
luaL_putchar(&b, data[pos]);
pos++;
}
if (pos < len) { /* found '\n' */
buf_skip(L, buf, pos+1); /* skip '\n' too */
break; /* we are done */
} else /* reached the end of the buffer */
buf_skip(L, buf, pos);
buf_skip(buf, pos);
}
luaL_pushresult(&b);
return err;
@ -303,38 +216,32 @@ static int recvunixline(lua_State *L, p_buf buf)
/*-------------------------------------------------------------------------*\
* Skips a given number of bytes in read buffer
* Input
* buf: buffer structure
* len: number of bytes to skip
\*-------------------------------------------------------------------------*/
static void buf_skip(lua_State *L, p_buf buf, size_t len)
static
void buf_skip(p_buf buf, size_t count)
{
buf->buf_first += len;
if (buf_isempty(L, buf)) buf->buf_first = buf->buf_last = 0;
buf->first += count;
if (buf_isempty(buf))
buf->first = buf->last = 0;
}
/*-------------------------------------------------------------------------*\
* Return any data available in buffer, or get more data from transport layer
* if buffer is empty.
* Input
* buf: buffer structure
* Output
* data: pointer to buffer start
* len: buffer buffer length
* Returns
* PRIV_DONE, PRIV_CLOSED, PRIV_TIMEOUT ...
\*-------------------------------------------------------------------------*/
static int buf_contents(lua_State *L, p_buf buf, cchar **data, size_t *len)
static
int buf_get(p_buf buf, const char **data, size_t *count)
{
int err = PRIV_DONE;
p_base base = buf->buf_base;
if (buf_isempty(L, buf)) {
size_t done;
err = base->base_receive(L, base, buf->buf_data, BUF_SIZE, &done);
buf->buf_first = 0;
buf->buf_last = done;
int err = IO_DONE;
p_io io = buf->io;
p_tm tm = buf->tm;
if (buf_isempty(buf)) {
size_t got;
err = io->recv(io->ctx, buf->data, BUF_SIZE, &got, tm_get(tm));
buf->first = 0;
buf->last = got;
}
*len = buf->buf_last - buf->buf_first;
*data = buf->buf_data + buf->buf_first;
*count = buf->last - buf->first;
*data = buf->data + buf->first;
return err;
}

View File

@ -3,11 +3,12 @@
*
* RCS ID: $Id$
\*=========================================================================*/
#ifndef BUF_H_
#define BUF_H_
#ifndef BUF_H
#define BUF_H
#include <lua.h>
#include "lsbase.h"
#include "io.h"
#include "tm.h"
/* buffer size in bytes */
#define BUF_SIZE 8192
@ -15,10 +16,11 @@
/*-------------------------------------------------------------------------*\
* Buffer control structure
\*-------------------------------------------------------------------------*/
typedef struct t_buf_tag {
size_t buf_first, buf_last;
char buf_data[BUF_SIZE];
p_base buf_base;
typedef struct t_buf_ {
p_io io; /* IO driver used for this buffer */
p_tm tm; /* timeout management for this buffer */
size_t first, last; /* index of first and last bytes of stored data */
char data[BUF_SIZE]; /* storage space for buffer data */
} t_buf;
typedef t_buf *p_buf;
@ -26,9 +28,9 @@ typedef t_buf *p_buf;
* Exported functions
\*-------------------------------------------------------------------------*/
void buf_open(lua_State *L);
void buf_init(lua_State *L, p_buf buf, p_base base);
int buf_send(lua_State *L, p_buf buf);
int buf_receive(lua_State *L, p_buf buf);
int buf_isempty(lua_State *L, p_buf buf);
void buf_init(p_buf buf, p_io io, p_tm tm);
int buf_meth_send(lua_State *L, p_buf buf);
int buf_meth_receive(lua_State *L, p_buf buf);
int buf_isempty(p_buf buf);
#endif /* BUF_H_ */
#endif /* BUF_H */

View File

@ -7,7 +7,8 @@
-----------------------------------------------------------------------------
local Public, Private = {}, {}
socket.ftp = Public
local socket = _G[LUASOCKET_LIBNAME] -- get LuaSocket namespace
socket.ftp = Public -- create ftp sub namespace
-----------------------------------------------------------------------------
-- Program constants
@ -22,6 +23,33 @@ Public.EMAIL = "anonymous@anonymous.org"
-- block size used in transfers
Public.BLOCKSIZE = 8192
-----------------------------------------------------------------------------
-- Tries to get a pattern from the server and closes socket on error
-- sock: socket connected to the server
-- pattern: pattern to receive
-- Returns
-- received pattern on success
-- nil followed by error message on error
-----------------------------------------------------------------------------
function Private.try_receive(sock, pattern)
local data, err = sock:receive(pattern)
if not data then sock:close() end
return data, err
end
-----------------------------------------------------------------------------
-- Tries to send data to the server and closes socket on error
-- sock: socket connected to the server
-- data: data to send
-- Returns
-- err: error message if any, nil if successfull
-----------------------------------------------------------------------------
function Private.try_send(sock, data)
local sent, err = sock:send(data)
if not sent then sock:close() end
return err
end
-----------------------------------------------------------------------------
-- Tries to send DOS mode lines. Closes socket on error.
-- Input
@ -31,24 +59,7 @@ Public.BLOCKSIZE = 8192
-- err: message in case of error, nil if successfull
-----------------------------------------------------------------------------
function Private.try_sendline(sock, line)
local err = sock:send(line .. "\r\n")
if err then sock:close() end
return err
end
-----------------------------------------------------------------------------
-- Tries to get a pattern from the server and closes socket on error
-- sock: socket connected to the server
-- ...: pattern to receive
-- Returns
-- ...: received pattern
-- err: error message if any
-----------------------------------------------------------------------------
function Private.try_receive(...)
local sock = arg[1]
local data, err = sock.receive(unpack(arg))
if err then sock:close() end
return data, err
return Private.try_send(sock, line .. "\r\n")
end
-----------------------------------------------------------------------------
@ -307,20 +318,20 @@ end
-- nil if successfull, or an error message in case of error
-----------------------------------------------------------------------------
function Private.send_indirect(data, send_cb, chunk, size)
local sent, err
sent = 0
local total, sent, err
total = 0
while 1 do
if type(chunk) ~= "string" or type(size) ~= "number" then
data:close()
if not chunk and type(size) == "string" then return size
else return "invalid callback return" end
end
err = data:send(chunk)
sent, err = data:send(chunk)
if err then
data:close()
return err
end
sent = sent + string.len(chunk)
total = total + sent
if sent >= size then break end
chunk, size = send_cb()
end

View File

@ -7,7 +7,8 @@
-----------------------------------------------------------------------------
local Public, Private = {}, {}
socket.http = Public
local socket = _G[LUASOCKET_LIBNAME] -- get LuaSocket namespace
socket.http = Public -- create http sub namespace
-----------------------------------------------------------------------------
-- Program constants
@ -24,19 +25,15 @@ Public.BLOCKSIZE = 8192
-----------------------------------------------------------------------------
-- Tries to get a pattern from the server and closes socket on error
-- sock: socket connected to the server
-- ...: pattern to receive
-- pattern: pattern to receive
-- Returns
-- ...: received pattern
-- err: error message if any
-- received pattern on success
-- nil followed by error message on error
-----------------------------------------------------------------------------
function Private.try_receive(...)
local sock = arg[1]
local data, err = sock.receive(unpack(arg))
if err then
sock:close()
return nil, err
end
return data
function Private.try_receive(sock, pattern)
local data, err = sock:receive(pattern)
if not data then sock:close() end
return data, err
end
-----------------------------------------------------------------------------
@ -47,8 +44,8 @@ end
-- err: error message if any, nil if successfull
-----------------------------------------------------------------------------
function Private.try_send(sock, data)
local err = sock:send(data)
if err then sock:close() end
local sent, err = sock:send(data)
if not sent then sock:close() end
return err
end
@ -285,21 +282,21 @@ end
-- nil if successfull, or an error message in case of error
-----------------------------------------------------------------------------
function Private.send_indirect(data, send_cb, chunk, size)
local sent, err
sent = 0
local total, sent, err
total = 0
while 1 do
if type(chunk) ~= "string" or type(size) ~= "number" then
data:close()
if not chunk and type(size) == "string" then return size
else return "invalid callback return" end
end
err = data:send(chunk)
sent, err = data:send(chunk)
if err then
data:close()
return err
end
sent = sent + string.len(chunk)
if sent >= size then break end
total = total + sent
if total >= size then break end
chunk, size = send_cb()
end
end

View File

@ -1,12 +1,5 @@
/*=========================================================================*\
* Internet domain class: inherits from the Socket class, and implement
* a few methods shared by all internet related objects
* Lua methods:
* getpeername: gets socket peer ip address and port
* getsockname: gets local socket ip address and port
* Global Lua fuctions:
* toip: gets resolver info on host name
* tohostname: gets resolver info on dotted-quad
* Internet domain functions
*
* RCS ID: $Id$
\*=========================================================================*/
@ -15,23 +8,27 @@
#include <lua.h>
#include <lauxlib.h>
#include "lsinet.h"
#include "lssock.h"
#include "lscompat.h"
#include "luasocket.h"
#include "inet.h"
/*=========================================================================*\
* Internal function prototypes.
\*=========================================================================*/
static int inet_lua_toip(lua_State *L);
static int inet_lua_tohostname(lua_State *L);
static int inet_lua_getpeername(lua_State *L);
static int inet_lua_getsockname(lua_State *L);
static int inet_global_toip(lua_State *L);
static int inet_global_tohostname(lua_State *L);
static void inet_pushresolved(lua_State *L, struct hostent *hp);
#ifdef COMPAT_INETATON
static int inet_aton(cchar *cp, struct in_addr *inp);
#ifdef INET_ATON
static int inet_aton(const char *cp, struct in_addr *inp);
#endif
static luaL_reg func[] = {
{ "toip", inet_global_toip },
{ "tohostname", inet_global_tohostname },
{ NULL, NULL}
};
/*=========================================================================*\
* Exported functions
\*=========================================================================*/
@ -40,39 +37,7 @@ static int inet_aton(cchar *cp, struct in_addr *inp);
\*-------------------------------------------------------------------------*/
void inet_open(lua_State *L)
{
lua_pushcfunction(L, inet_lua_toip);
priv_newglobal(L, "toip");
lua_pushcfunction(L, inet_lua_tohostname);
priv_newglobal(L, "tohostname");
priv_newglobalmethod(L, "getsockname");
priv_newglobalmethod(L, "getpeername");
}
/*-------------------------------------------------------------------------*\
* Hook lua methods to methods table.
* Input
* lsclass: class name
\*-------------------------------------------------------------------------*/
void inet_inherit(lua_State *L, cchar *lsclass)
{
unsigned int i;
static struct luaL_reg funcs[] = {
{"getsockname", inet_lua_getsockname},
{"getpeername", inet_lua_getpeername},
};
sock_inherit(L, lsclass);
for (i = 0; i < sizeof(funcs)/sizeof(funcs[0]); i++) {
lua_pushcfunction(L, funcs[i].func);
priv_setmethod(L, lsclass, funcs[i].name);
}
}
/*-------------------------------------------------------------------------*\
* Constructs the object
\*-------------------------------------------------------------------------*/
void inet_construct(lua_State *L, p_inet inet)
{
sock_construct(L, (p_sock) inet);
luaL_openlib(L, LUASOCKET_LIBNAME, func, 0);
}
/*=========================================================================*\
@ -87,17 +52,18 @@ void inet_construct(lua_State *L, p_inet inet)
* On success: first IP address followed by a resolved table
* On error: nil, followed by an error message
\*-------------------------------------------------------------------------*/
static int inet_lua_toip(lua_State *L)
static int inet_global_toip(lua_State *L)
{
cchar *address = luaL_checkstring(L, 1);
const char *address = luaL_checkstring(L, 1);
struct in_addr addr;
struct hostent *hp;
if (inet_aton(address, &addr))
hp = gethostbyaddr((char *) &addr, sizeof(addr), AF_INET);
else hp = gethostbyname(address);
else
hp = gethostbyname(address);
if (!hp) {
lua_pushnil(L);
lua_pushstring(L, compat_hoststrerror());
lua_pushstring(L, sock_hoststrerror());
return 2;
}
addr = *((struct in_addr *) hp->h_addr);
@ -115,17 +81,18 @@ static int inet_lua_toip(lua_State *L)
* On success: canonic name followed by a resolved table
* On error: nil, followed by an error message
\*-------------------------------------------------------------------------*/
static int inet_lua_tohostname(lua_State *L)
static int inet_global_tohostname(lua_State *L)
{
cchar *address = luaL_checkstring(L, 1);
const char *address = luaL_checkstring(L, 1);
struct in_addr addr;
struct hostent *hp;
if (inet_aton(address, &addr))
hp = gethostbyaddr((char *) &addr, sizeof(addr), AF_INET);
else hp = gethostbyname(address);
else
hp = gethostbyname(address);
if (!hp) {
lua_pushnil(L);
lua_pushstring(L, compat_hoststrerror());
lua_pushstring(L, sock_hoststrerror());
return 2;
}
lua_pushstring(L, hp->h_name);
@ -138,18 +105,17 @@ static int inet_lua_tohostname(lua_State *L)
\*=========================================================================*/
/*-------------------------------------------------------------------------*\
* Retrieves socket peer name
* Lua Input: sock
* Input:
* sock: socket
* Lua Returns
* On success: ip address and port of peer
* On error: nil
\*-------------------------------------------------------------------------*/
static int inet_lua_getpeername(lua_State *L)
int inet_meth_getpeername(lua_State *L, p_sock ps)
{
p_sock sock = (p_sock) lua_touserdata(L, 1);
struct sockaddr_in peer;
size_t peer_len = sizeof(peer);
if (getpeername(sock->fd, (SA *) &peer, &peer_len) < 0) {
if (getpeername(*ps, (SA *) &peer, &peer_len) < 0) {
lua_pushnil(L);
return 1;
}
@ -160,18 +126,17 @@ static int inet_lua_getpeername(lua_State *L)
/*-------------------------------------------------------------------------*\
* Retrieves socket local name
* Lua Input: sock
* Input:
* sock: socket
* Lua Returns
* On success: local ip address and port
* On error: nil
\*-------------------------------------------------------------------------*/
static int inet_lua_getsockname(lua_State *L)
int inet_meth_getsockname(lua_State *L, p_sock ps)
{
p_sock sock = (p_sock) lua_touserdata(L, 1);
struct sockaddr_in local;
size_t local_len = sizeof(local);
if (getsockname(sock->fd, (SA *) &local, &local_len) < 0) {
if (getsockname(*ps, (SA *) &local, &local_len) < 0) {
lua_pushnil(L);
return 1;
}
@ -222,47 +187,53 @@ static void inet_pushresolved(lua_State *L, struct hostent *hp)
}
/*-------------------------------------------------------------------------*\
* Tries to create a TCP socket and connect to remote address (address, port)
* Tries to connect to remote address (address, port)
* Input
* client: socket structure to be used
* ps: pointer to socket
* address: host name or ip address
* port: port number to bind to
* Returns
* NULL in case of success, error message otherwise
\*-------------------------------------------------------------------------*/
cchar *inet_tryconnect(p_inet inet, cchar *address, ushort port)
const char *inet_tryconnect(p_sock ps, const char *address, ushort port)
{
struct sockaddr_in remote;
memset(&remote, 0, sizeof(remote));
remote.sin_family = AF_INET;
remote.sin_port = htons(port);
if (!strlen(address) || !inet_aton(address, &remote.sin_addr)) {
struct hostent *hp = gethostbyname(address);
struct in_addr **addr;
if (!hp) return compat_hoststrerror();
addr = (struct in_addr **) hp->h_addr_list;
memcpy(&remote.sin_addr, *addr, sizeof(struct in_addr));
}
compat_setblocking(inet->fd);
if (compat_connect(inet->fd, (SA *) &remote, sizeof(remote)) < 0) {
const char *err = compat_connectstrerror();
compat_close(inet->fd);
inet->fd = COMPAT_INVALIDFD;
if (strcmp(address, "*")) {
if (!strlen(address) || !inet_aton(address, &remote.sin_addr)) {
struct hostent *hp = gethostbyname(address);
struct in_addr **addr;
remote.sin_family = AF_INET;
if (!hp) return sock_hoststrerror();
addr = (struct in_addr **) hp->h_addr_list;
memcpy(&remote.sin_addr, *addr, sizeof(struct in_addr));
}
} else remote.sin_family = AF_UNSPEC;
sock_setblocking(ps);
const char *err = sock_connect(ps, (SA *) &remote, sizeof(remote));
if (err) {
sock_destroy(ps);
*ps = SOCK_INVALID;
return err;
} else {
sock_setnonblocking(ps);
return NULL;
}
compat_setnonblocking(inet->fd);
return NULL;
}
/*-------------------------------------------------------------------------*\
* Tries to create a TCP socket and bind it to (address, port)
* Tries to bind socket to (address, port)
* Input
* sock: pointer to socket
* address: host name or ip address
* port: port number to bind to
* Returns
* NULL in case of success, error message otherwise
\*-------------------------------------------------------------------------*/
cchar *inet_trybind(p_inet inet, cchar *address, ushort port)
const char *inet_trybind(p_sock ps, const char *address, ushort port,
int backlog)
{
struct sockaddr_in local;
memset(&local, 0, sizeof(local));
@ -274,34 +245,33 @@ cchar *inet_trybind(p_inet inet, cchar *address, ushort port)
(!strlen(address) || !inet_aton(address, &local.sin_addr))) {
struct hostent *hp = gethostbyname(address);
struct in_addr **addr;
if (!hp) return compat_hoststrerror();
if (!hp) return sock_hoststrerror();
addr = (struct in_addr **) hp->h_addr_list;
memcpy(&local.sin_addr, *addr, sizeof(struct in_addr));
}
compat_setblocking(inet->fd);
if (compat_bind(inet->fd, (SA *) &local, sizeof(local)) < 0) {
const char *err = compat_bindstrerror();
compat_close(inet->fd);
inet->fd = COMPAT_INVALIDFD;
sock_setblocking(ps);
const char *err = sock_bind(ps, (SA *) &local, sizeof(local));
if (err) {
sock_destroy(ps);
*ps = SOCK_INVALID;
return err;
} else {
sock_setnonblocking(ps);
if (backlog > 0) sock_listen(ps, backlog);
return NULL;
}
compat_setnonblocking(inet->fd);
return NULL;
}
/*-------------------------------------------------------------------------*\
* Tries to create a new inet socket
* Input
* udp: udp structure
* sock: pointer to socket
* Returns
* NULL if successfull, error message on error
\*-------------------------------------------------------------------------*/
cchar *inet_trysocket(p_inet inet, int type)
const char *inet_trycreate(p_sock ps, int type)
{
if (inet->fd != COMPAT_INVALIDFD) compat_close(inet->fd);
inet->fd = compat_socket(AF_INET, type, 0);
if (inet->fd == COMPAT_INVALIDFD) return compat_socketstrerror();
else return NULL;
return sock_create(ps, AF_INET, type, 0);
}
/*-------------------------------------------------------------------------*\

View File

@ -1,38 +1,26 @@
/*=========================================================================*\
* Internet domain class: inherits from the Socket class, and implement
* a few methods shared by all internet related objects
* Internet domain functions
*
* RCS ID: $Id$
\*=========================================================================*/
#ifndef INET_H_
#define INET_H_
#ifndef INET_H
#define INET_H
#include <lua.h>
#include "lssock.h"
/* class name */
#define INET_CLASS "luasocket(inet)"
/*-------------------------------------------------------------------------*\
* Socket fields
\*-------------------------------------------------------------------------*/
#define INET_FIELDS SOCK_FIELDS
/*-------------------------------------------------------------------------*\
* Socket structure
\*-------------------------------------------------------------------------*/
typedef t_sock t_inet;
typedef t_inet *p_inet;
#include "sock.h"
/*-------------------------------------------------------------------------*\
* Exported functions
\*-------------------------------------------------------------------------*/
void inet_open(lua_State *L);
void inet_construct(lua_State *L, p_inet inet);
void inet_inherit(lua_State *L, cchar *lsclass);
cchar *inet_tryconnect(p_sock sock, cchar *address, ushort);
cchar *inet_trybind(p_sock sock, cchar *address, ushort);
cchar *inet_trysocket(p_inet inet, int type);
const char *inet_tryconnect(p_sock ps, const char *address,
unsigned short port);
const char *inet_trybind(p_sock ps, const char *address,
unsigned short port, int backlog);
const char *inet_trycreate(p_sock ps, int type);
int inet_meth_getpeername(lua_State *L, p_sock ps);
int inet_meth_getsockname(lua_State *L, p_sock ps);
#endif /* INET_H_ */

8
src/io.c Normal file
View File

@ -0,0 +1,8 @@
#include "io.h"
void io_init(p_io io, p_send send, p_recv recv, void *ctx)
{
io->send = send;
io->recv = recv;
io->ctx = ctx;
}

34
src/io.h Normal file
View File

@ -0,0 +1,34 @@
#ifndef IO_H
#define IO_H
#include "error.h"
/* interface to send function */
typedef int (*p_send) (
void *ctx, /* context needed by send */
const char *data, /* pointer to buffer with data to send */
size_t count, /* number of bytes to send from buffer */
size_t *sent, /* number of bytes sent uppon return */
int timeout /* number of miliseconds left for transmission */
);
/* interface to recv function */
typedef int (*p_recv) (
void *ctx, /* context needed by recv */
char *data, /* pointer to buffer where data will be writen */
size_t count, /* number of bytes to receive into buffer */
size_t *got, /* number of bytes received uppon return */
int timeout /* number of miliseconds left for transmission */
);
/* IO driver definition */
typedef struct t_io_ {
void *ctx; /* context needed by send/recv */
p_send send; /* send function pointer */
p_recv recv; /* receive function pointer */
} t_io;
typedef t_io *p_io;
void io_init(p_io io, p_send send, p_recv recv, void *ctx);
#endif /* IO_H */

View File

@ -23,18 +23,13 @@
* LuaSocket includes
\*=========================================================================*/
#include "luasocket.h"
#include "lspriv.h"
#include "lsselect.h"
#include "lscompat.h"
#include "lsbase.h"
#include "lstm.h"
#include "lsbuf.h"
#include "lssock.h"
#include "lsinet.h"
#include "lstcpc.h"
#include "lstcps.h"
#include "lstcps.h"
#include "lsudp.h"
#include "tm.h"
#include "buf.h"
#include "sock.h"
#include "inet.h"
#include "tcp.h"
#include "udp.h"
/*=========================================================================*\
* Exported functions
@ -42,34 +37,29 @@
/*-------------------------------------------------------------------------*\
* Initializes all library modules.
\*-------------------------------------------------------------------------*/
LUASOCKET_API int lua_socketlibopen(lua_State *L)
LUASOCKET_API int luaopen_socketlib(lua_State *L)
{
compat_open(L);
priv_open(L);
select_open(L);
base_open(L);
tm_open(L);
fd_open(L);
sock_open(L);
inet_open(L);
tcpc_open(L);
buf_open(L);
tcps_open(L);
udp_open(L);
#ifdef LUASOCKET_DOFILE
lua_dofile(L, "concat.lua");
lua_dofile(L, "code.lua");
lua_dofile(L, "url.lua");
lua_dofile(L, "http.lua");
lua_dofile(L, "smtp.lua");
lua_dofile(L, "ftp.lua");
#else
#include "concat.loh"
#include "code.loh"
#include "url.loh"
#include "http.loh"
#include "smtp.loh"
#include "ftp.loh"
/* create namespace table */
lua_pushstring(L, LUASOCKET_LIBNAME);
lua_newtable(L);
#ifdef LUASOCKET_DEBUG
lua_pushstring(L, "debug");
lua_pushnumber(L, 1);
lua_settable(L, -3);
#endif
lua_settable(L, LUA_GLOBALSINDEX);
/* make sure modules know what is our namespace */
lua_pushstring(L, "LUASOCKET_LIBNAME");
lua_pushstring(L, LUASOCKET_LIBNAME);
lua_settable(L, LUA_GLOBALSINDEX);
/* initialize all modules */
sock_open(L);
tm_open(L);
buf_open(L);
inet_open(L);
tcp_open(L);
udp_open(L);
/* load all Lua code */
lua_dofile(L, "luasocket.lua");
return 0;
}

View File

@ -5,8 +5,8 @@
*
* RCS ID: $Id$
\*=========================================================================*/
#ifndef _LUASOCKET_H_
#define _LUASOCKET_H_
#ifndef LUASOCKET_H
#define LUASOCKET_H
/*-------------------------------------------------------------------------*\
* Current luasocket version
@ -28,6 +28,6 @@
/*-------------------------------------------------------------------------*\
* Initializes the library.
\*-------------------------------------------------------------------------*/
LUASOCKET_API int lua_socketlibopen(lua_State *L);
LUASOCKET_API int luaopen_socketlib(lua_State *L);
#endif /* _LUASOCKET_H_ */
#endif /* LUASOCKET_H */

View File

@ -5,10 +5,10 @@ mbox = Public
function Public.split_message(message_s)
local message = {}
message_s = string.gsub(message_s, "\r\n", "\n")
string.gsub(message_s, "^(.-\n)\n", function (h) %message.headers = h end)
string.gsub(message_s, "^.-\n\n(.*)", function (b) %message.body = b end)
string.gsub(message_s, "^(.-\n)\n", function (h) message.headers = h end)
string.gsub(message_s, "^.-\n\n(.*)", function (b) message.body = b end)
if not message.body then
string.gsub(message_s, "^\n(.*)", function (b) %message.body = b end)
string.gsub(message_s, "^\n(.*)", function (b) message.body = b end)
end
if not message.headers and not message.body then
message.headers = message_s
@ -20,7 +20,7 @@ function Public.split_headers(headers_s)
local headers = {}
headers_s = string.gsub(headers_s, "\r\n", "\n")
headers_s = string.gsub(headers_s, "\n[ ]+", " ")
string.gsub("\n" .. headers_s, "\n([^\n]+)", function (h) table.insert(%headers, h) end)
string.gsub("\n" .. headers_s, "\n([^\n]+)", function (h) table.insert(headers, h) end)
return headers
end
@ -32,10 +32,10 @@ function Public.parse_header(header_s)
end
function Public.parse_headers(headers_s)
local headers_t = %Public.split_headers(headers_s)
local headers_t = Public.split_headers(headers_s)
local headers = {}
for i = 1, table.getn(headers_t) do
local name, value = %Public.parse_header(headers_t[i])
local name, value = Public.parse_header(headers_t[i])
if name then
name = string.lower(name)
if headers[name] then
@ -73,16 +73,16 @@ function Public.split_mbox(mbox_s)
end
function Public.parse(mbox_s)
local mbox = %Public.split_mbox(mbox_s)
local mbox = Public.split_mbox(mbox_s)
for i = 1, table.getn(mbox) do
mbox[i] = %Public.parse_message(mbox[i])
mbox[i] = Public.parse_message(mbox[i])
end
return mbox
end
function Public.parse_message(message_s)
local message = {}
message.headers, message.body = %Public.split_message(message_s)
message.headers = %Public.parse_headers(message.headers)
message.headers, message.body = Public.split_message(message_s)
message.headers = Public.parse_headers(message.headers)
return message
end

View File

@ -7,7 +7,8 @@
-----------------------------------------------------------------------------
local Public, Private = {}, {}
socket.smtp = Public
local socket = _G[LUASOCKET_LIBNAME] -- get LuaSocket namespace
socket.smtp = Public -- create smtp sub namespace
-----------------------------------------------------------------------------
-- Program constants
@ -23,32 +24,30 @@ Public.DOMAIN = os.getenv("SERVER_NAME") or "localhost"
Public.SERVER = "localhost"
-----------------------------------------------------------------------------
-- Tries to send data through socket. Closes socket on error.
-- Input
-- sock: server socket
-- data: string to be sent
-- Tries to get a pattern from the server and closes socket on error
-- sock: socket connected to the server
-- pattern: pattern to receive
-- Returns
-- err: message in case of error, nil if successfull
-- received pattern on success
-- nil followed by error message on error
-----------------------------------------------------------------------------
function Private.try_send(sock, data)
local err = sock:send(data)
if err then sock:close() end
return err
function Private.try_receive(sock, pattern)
local data, err = sock:receive(pattern)
if not data then sock:close() end
return data, err
end
-----------------------------------------------------------------------------
-- Tries to get a pattern from the server and closes socket on error
-- sock: socket opened to the server
-- ...: pattern to receive
-- Tries to send data to the server and closes socket on error
-- sock: socket connected to the server
-- data: data to send
-- Returns
-- ...: received pattern
-- err: error message if any
-- err: error message if any, nil if successfull
-----------------------------------------------------------------------------
function Private.try_receive(...)
local sock = arg[1]
local data, err = sock.receive(unpack(arg))
if err then sock:close() end
return data, err
function Private.try_send(sock, data)
local sent, err = sock:send(data)
if not sent then sock:close() end
return err
end
-----------------------------------------------------------------------------

222
src/tcp.c Normal file
View File

@ -0,0 +1,222 @@
/*=========================================================================*\
* TCP object
*
* RCS ID: $Id$
\*=========================================================================*/
#include <string.h>
#include <lua.h>
#include <lauxlib.h>
#include "luasocket.h"
#include "aux.h"
#include "inet.h"
#include "tcp.h"
/*=========================================================================*\
* Internal function prototypes
\*=========================================================================*/
static int tcp_global_create(lua_State *L);
static int tcp_meth_connect(lua_State *L);
static int tcp_meth_bind(lua_State *L);
static int tcp_meth_send(lua_State *L);
static int tcp_meth_getsockname(lua_State *L);
static int tcp_meth_getpeername(lua_State *L);
static int tcp_meth_receive(lua_State *L);
static int tcp_meth_accept(lua_State *L);
static int tcp_meth_close(lua_State *L);
static int tcp_meth_timeout(lua_State *L);
/* tcp object methods */
static luaL_reg tcp[] = {
{"connect", tcp_meth_connect},
{"send", tcp_meth_send},
{"receive", tcp_meth_receive},
{"bind", tcp_meth_bind},
{"accept", tcp_meth_accept},
{"setpeername", tcp_meth_connect},
{"setsockname", tcp_meth_bind},
{"getpeername", tcp_meth_getpeername},
{"getsockname", tcp_meth_getsockname},
{"timeout", tcp_meth_timeout},
{"close", tcp_meth_close},
{NULL, NULL}
};
/* functions in library namespace */
static luaL_reg func[] = {
{"tcp", tcp_global_create},
{NULL, NULL}
};
/*-------------------------------------------------------------------------*\
* Initializes module
\*-------------------------------------------------------------------------*/
void tcp_open(lua_State *L)
{
/* create classes */
aux_newclass(L, "tcp{master}", tcp);
aux_newclass(L, "tcp{client}", tcp);
aux_newclass(L, "tcp{server}", tcp);
/* create class groups */
aux_add2group(L, "tcp{client}", "tcp{client, server}");
aux_add2group(L, "tcp{server}", "tcp{client, server}");
aux_add2group(L, "tcp{master}", "tcp{any}");
aux_add2group(L, "tcp{client}", "tcp{any}");
aux_add2group(L, "tcp{server}", "tcp{any}");
/* define library functions */
luaL_openlib(L, LUASOCKET_LIBNAME, func, 0);
lua_pop(L, 1);
}
/*=========================================================================*\
* Lua methods
\*=========================================================================*/
/*-------------------------------------------------------------------------*\
* Just call buffered IO methods
\*-------------------------------------------------------------------------*/
static int tcp_meth_send(lua_State *L)
{
p_tcp tcp = (p_tcp) aux_checkclass(L, "tcp{client}", 1);
return buf_meth_send(L, &tcp->buf);
}
static int tcp_meth_receive(lua_State *L)
{
p_tcp tcp = (p_tcp) aux_checkclass(L, "tcp{client}", 1);
return buf_meth_receive(L, &tcp->buf);
}
/*-------------------------------------------------------------------------*\
* Just call inet methods
\*-------------------------------------------------------------------------*/
static int tcp_meth_getpeername(lua_State *L)
{
p_tcp tcp = (p_tcp) aux_checkclass(L, "tcp{client}", 1);
return inet_meth_getpeername(L, &tcp->sock);
}
static int tcp_meth_getsockname(lua_State *L)
{
p_tcp tcp = (p_tcp) aux_checkgroup(L, "tcp{client, server}", 1);
return inet_meth_getsockname(L, &tcp->sock);
}
/*-------------------------------------------------------------------------*\
* Just call tm methods
\*-------------------------------------------------------------------------*/
static int tcp_meth_timeout(lua_State *L)
{
p_tcp tcp = (p_tcp) aux_checkgroup(L, "tcp{any}", 1);
return tm_meth_timeout(L, &tcp->tm);
}
/*-------------------------------------------------------------------------*\
* Closes socket used by object
\*-------------------------------------------------------------------------*/
static int tcp_meth_close(lua_State *L)
{
p_tcp tcp = (p_tcp) aux_checkgroup(L, "tcp{any}", 1);
sock_destroy(&tcp->sock);
return 0;
}
/*-------------------------------------------------------------------------*\
* Turns a master tcp object into a client object.
\*-------------------------------------------------------------------------*/
static int tcp_meth_connect(lua_State *L)
{
p_tcp tcp = (p_tcp) aux_checkclass(L, "tcp{master}", 1);
const char *address = luaL_checkstring(L, 2);
unsigned short port = (ushort) luaL_checknumber(L, 3);
const char *err = inet_tryconnect(&tcp->sock, address, port);
if (err) {
lua_pushnil(L);
lua_pushstring(L, err);
return 2;
}
/* turn master object into a client object */
aux_setclass(L, "tcp{client}", 1);
lua_pushnumber(L, 1);
return 1;
}
/*-------------------------------------------------------------------------*\
* Turns a master object into a server object
\*-------------------------------------------------------------------------*/
static int tcp_meth_bind(lua_State *L)
{
p_tcp tcp = (p_tcp) aux_checkclass(L, "tcp{master}", 1);
const char *address = luaL_checkstring(L, 2);
unsigned short port = (ushort) luaL_checknumber(L, 3);
int backlog = (int) luaL_optnumber(L, 4, 1);
const char *err = inet_trybind(&tcp->sock, address, port, backlog);
if (err) {
lua_pushnil(L);
lua_pushstring(L, err);
return 2;
}
/* turn master object into a server object */
aux_setclass(L, "tcp{server}", 1);
lua_pushnumber(L, 1);
return 1;
}
/*-------------------------------------------------------------------------*\
* Waits for and returns a client object attempting connection to the
* server object
\*-------------------------------------------------------------------------*/
static int tcp_meth_accept(lua_State *L)
{
struct sockaddr_in addr;
size_t addr_len = sizeof(addr);
p_tcp server = (p_tcp) aux_checkclass(L, "tcp{server}", 1);
p_tm tm = &server->tm;
p_tcp client = lua_newuserdata(L, sizeof(t_tcp));
tm_markstart(tm);
aux_setclass(L, "tcp{client}", -1);
for ( ;; ) {
sock_accept(&server->sock, &client->sock,
(SA *) &addr, &addr_len, tm_get(tm));
if (client->sock == SOCK_INVALID) {
if (tm_get(tm) == 0) {
lua_pushnil(L);
error_push(L, IO_TIMEOUT);
return 2;
}
} else break;
}
/* initialize remaining structure fields */
io_init(&client->io, (p_send) sock_send, (p_recv) sock_recv, &client->sock);
tm_init(&client->tm, -1, -1);
buf_init(&client->buf, &client->io, &client->tm);
return 1;
}
/*=========================================================================*\
* Library functions
\*=========================================================================*/
/*-------------------------------------------------------------------------*\
* Creates a master tcp object
\*-------------------------------------------------------------------------*/
int tcp_global_create(lua_State *L)
{
/* allocate tcp object */
p_tcp tcp = (p_tcp) lua_newuserdata(L, sizeof(t_tcp));
/* set its type as master object */
aux_setclass(L, "tcp{master}", -1);
/* try to allocate a system socket */
const char *err = inet_trycreate(&tcp->sock, SOCK_STREAM);
if (err) { /* get rid of object on stack and push error */
lua_pop(L, 1);
lua_pushnil(L);
lua_pushstring(L, err);
return 2;
}
/* initialize remaining structure fields */
io_init(&tcp->io, (p_send) sock_send, (p_recv) sock_recv, &tcp->sock);
tm_init(&tcp->tm, -1, -1);
buf_init(&tcp->buf, &tcp->io, &tcp->tm);
return 1;
}

20
src/tcp.h Normal file
View File

@ -0,0 +1,20 @@
#ifndef TCP_H
#define TCP_H
#include <lua.h>
#include "buf.h"
#include "tm.h"
#include "sock.h"
typedef struct t_tcp_ {
t_sock sock;
t_io io;
t_buf buf;
t_tm tm;
} t_tcp;
typedef t_tcp *p_tcp;
void tcp_open(lua_State *L);
#endif

View File

@ -1,18 +1,19 @@
/*=========================================================================*\
* Timeout management functions
* Global Lua functions:
* _sleep: (debug mode only)
* _time: (debug mode only)
* _sleep
* _time
*
* RCS ID: $Id$
\*=========================================================================*/
#include <stdio.h>
#include <lua.h>
#include <lauxlib.h>
#include "lspriv.h"
#include "lstm.h"
#include <stdio.h>
#include "luasocket.h"
#include "aux.h"
#include "tm.h"
#ifdef WIN32
#include <windows.h>
@ -28,78 +29,69 @@
static int tm_lua_time(lua_State *L);
static int tm_lua_sleep(lua_State *L);
static luaL_reg func[] = {
{ "time", tm_lua_time },
{ "sleep", tm_lua_sleep },
{ NULL, NULL }
};
/*=========================================================================*\
* Exported functions.
\*=========================================================================*/
/*-------------------------------------------------------------------------*\
* Sets timeout limits
* Input
* tm: timeout control structure
* mode: block or return timeout
* value: timeout value in miliseconds
* Initialize structure
\*-------------------------------------------------------------------------*/
void tm_set(p_tm tm, int tm_block, int tm_return)
void tm_init(p_tm tm, int block, int total)
{
tm->tm_block = tm_block;
tm->tm_return = tm_return;
tm->block = block;
tm->total = total;
}
/*-------------------------------------------------------------------------*\
* Returns timeout limits
* Input
* tm: timeout control structure
* mode: block or return timeout
* value: timeout value in miliseconds
* Set and get timeout limits
\*-------------------------------------------------------------------------*/
void tm_get(p_tm tm, int *tm_block, int *tm_return)
{
if (tm_block) *tm_block = tm->tm_block;
if (tm_return) *tm_return = tm->tm_return;
}
void tm_setblock(p_tm tm, int block)
{ tm->block = block; }
void tm_settotal(p_tm tm, int total)
{ tm->total = total; }
int tm_getblock(p_tm tm)
{ return tm->block; }
int tm_gettotal(p_tm tm)
{ return tm->total; }
int tm_getstart(p_tm tm)
{ return tm->start; }
/*-------------------------------------------------------------------------*\
* Determines how much time we have left for the current io operation
* an IO write operation.
* Determines how much time we have left for the current operation
* Input
* tm: timeout control structure
* Returns
* the number of ms left or -1 if there is no time limit
\*-------------------------------------------------------------------------*/
int tm_getremaining(p_tm tm)
int tm_get(p_tm tm)
{
/* no timeout */
if (tm->tm_block < 0 && tm->tm_return < 0)
if (tm->block < 0 && tm->total < 0)
return -1;
/* there is no block timeout, we use the return timeout */
else if (tm->tm_block < 0)
return MAX(tm->tm_return - tm_gettime() + tm->tm_start, 0);
else if (tm->block < 0)
return MAX(tm->total - tm_gettime() + tm->start, 0);
/* there is no return timeout, we use the block timeout */
else if (tm->tm_return < 0)
return tm->tm_block;
else if (tm->total < 0)
return tm->block;
/* both timeouts are specified */
else return MIN(tm->tm_block,
MAX(tm->tm_return - tm_gettime() + tm->tm_start, 0));
else return MIN(tm->block,
MAX(tm->total - tm_gettime() + tm->start, 0));
}
/*-------------------------------------------------------------------------*\
* Marks the operation start time in sock structure
* Marks the operation start time in structure
* Input
* tm: timeout control structure
\*-------------------------------------------------------------------------*/
void tm_markstart(p_tm tm)
{
tm->tm_start = tm_gettime();
tm->tm_end = tm->tm_start;
}
/*-------------------------------------------------------------------------*\
* Returns the length of the operation in ms
* Input
* tm: timeout control structure
\*-------------------------------------------------------------------------*/
int tm_getelapsed(p_tm tm)
{
return tm->tm_end - tm->tm_start;
tm->start = tm_gettime();
}
/*-------------------------------------------------------------------------*\
@ -125,11 +117,31 @@ int tm_gettime(void)
\*-------------------------------------------------------------------------*/
void tm_open(lua_State *L)
{
(void) L;
lua_pushcfunction(L, tm_lua_time);
priv_newglobal(L, "_time");
lua_pushcfunction(L, tm_lua_sleep);
priv_newglobal(L, "_sleep");
luaL_openlib(L, LUASOCKET_LIBNAME, func, 0);
}
/*-------------------------------------------------------------------------*\
* Sets timeout values for IO operations
* Lua Input: base, time [, mode]
* time: time out value in seconds
* mode: "b" for block timeout, "t" for total timeout. (default: b)
\*-------------------------------------------------------------------------*/
int tm_meth_timeout(lua_State *L, p_tm tm)
{
int ms = lua_isnil(L, 2) ? -1 : (int) (luaL_checknumber(L, 2)*1000.0);
const char *mode = luaL_optstring(L, 3, "b");
switch (*mode) {
case 'b':
tm_setblock(tm, ms);
break;
case 'r': case 't':
tm_settotal(tm, ms);
break;
default:
luaL_argcheck(L, 0, 3, "invalid timeout mode");
break;
}
return 0;
}
/*=========================================================================*\

View File

@ -3,23 +3,29 @@
*
* RCS ID: $Id$
\*=========================================================================*/
#ifndef _TM_H
#define _TM_H
#ifndef TM_H
#define TM_H
typedef struct t_tm_tag {
int tm_return;
int tm_block;
int tm_start;
int tm_end;
#include <lua.h>
/* timeout control structure */
typedef struct t_tm_ {
int total; /* total number of miliseconds for operation */
int block; /* maximum time for blocking calls */
int start; /* time of start of operation */
} t_tm;
typedef t_tm *p_tm;
void tm_set(p_tm tm, int tm_block, int tm_return);
int tm_getremaining(p_tm tm);
int tm_getelapsed(p_tm tm);
int tm_gettime(void);
void tm_get(p_tm tm, int *tm_block, int *tm_return);
void tm_markstart(p_tm tm);
void tm_open(lua_State *L);
void tm_init(p_tm tm, int block, int total);
void tm_setblock(p_tm tm, int block);
void tm_settotal(p_tm tm, int total);
int tm_getblock(p_tm tm);
int tm_gettotal(p_tm tm);
void tm_markstart(p_tm tm);
int tm_getstart(p_tm tm);
int tm_get(p_tm tm);
int tm_gettime(void);
int tm_meth_timeout(lua_State *L, p_tm tm);
#endif

438
src/udp.c
View File

@ -1,15 +1,5 @@
/*=========================================================================*\
* UDP class: inherits from Socked and Internet domain classes and provides
* all the functionality for UDP objects.
* Lua methods:
* send: using compat module
* sendto: using compat module
* receive: using compat module
* receivefrom: using compat module
* setpeername: using internet module
* setsockname: using internet module
* Global Lua functions:
* udp: creates the udp object
* UDP object
*
* RCS ID: $Id$
\*=========================================================================*/
@ -18,282 +8,256 @@
#include <lua.h>
#include <lauxlib.h>
#include "lsinet.h"
#include "lsudp.h"
#include "lscompat.h"
#include "lsselect.h"
#include "luasocket.h"
#include "aux.h"
#include "inet.h"
#include "udp.h"
/*=========================================================================*\
* Internal function prototypes.
* Internal function prototypes
\*=========================================================================*/
static int udp_lua_send(lua_State *L);
static int udp_lua_sendto(lua_State *L);
static int udp_lua_receive(lua_State *L);
static int udp_lua_receivefrom(lua_State *L);
static int udp_lua_setpeername(lua_State *L);
static int udp_lua_setsockname(lua_State *L);
static int udp_global_create(lua_State *L);
static int udp_meth_send(lua_State *L);
static int udp_meth_sendto(lua_State *L);
static int udp_meth_receive(lua_State *L);
static int udp_meth_receivefrom(lua_State *L);
static int udp_meth_getsockname(lua_State *L);
static int udp_meth_getpeername(lua_State *L);
static int udp_meth_setsockname(lua_State *L);
static int udp_meth_setpeername(lua_State *L);
static int udp_meth_close(lua_State *L);
static int udp_meth_timeout(lua_State *L);
static int udp_global_udp(lua_State *L);
static struct luaL_reg funcs[] = {
{"send", udp_lua_send},
{"sendto", udp_lua_sendto},
{"receive", udp_lua_receive},
{"receivefrom", udp_lua_receivefrom},
{"setpeername", udp_lua_setpeername},
{"setsockname", udp_lua_setsockname},
/* udp object methods */
static luaL_reg udp[] = {
{"setpeername", udp_meth_setpeername},
{"setsockname", udp_meth_setsockname},
{"getsockname", udp_meth_getsockname},
{"getpeername", udp_meth_getpeername},
{"send", udp_meth_send},
{"sendto", udp_meth_sendto},
{"receive", udp_meth_receive},
{"receivefrom", udp_meth_receivefrom},
{"timeout", udp_meth_timeout},
{"close", udp_meth_close},
{NULL, NULL}
};
/* functions in library namespace */
static luaL_reg func[] = {
{"udp", udp_global_create},
{NULL, NULL}
};
/*=========================================================================*\
* Exported functions
\*=========================================================================*/
/*-------------------------------------------------------------------------*\
* Initializes module
\*-------------------------------------------------------------------------*/
void udp_open(lua_State *L)
{
unsigned int i;
priv_newclass(L, UDP_CLASS);
udp_inherit(L, UDP_CLASS);
/* declare global functions */
lua_pushcfunction(L, udp_global_udp);
priv_newglobal(L, "udp");
for (i = 0; i < sizeof(funcs)/sizeof(funcs[0]); i++)
priv_newglobalmethod(L, funcs[i].name);
/* make class selectable */
select_addclass(L, UDP_CLASS);
}
/*-------------------------------------------------------------------------*\
* Hook object methods to methods table.
\*-------------------------------------------------------------------------*/
void udp_inherit(lua_State *L, cchar *lsclass)
{
unsigned int i;
inet_inherit(L, lsclass);
for (i = 0; i < sizeof(funcs)/sizeof(funcs[0]); i++) {
lua_pushcfunction(L, funcs[i].func);
priv_setmethod(L, lsclass, funcs[i].name);
}
}
/*-------------------------------------------------------------------------*\
* Initializes socket structure
\*-------------------------------------------------------------------------*/
void udp_construct(lua_State *L, p_udp udp)
{
inet_construct(L, (p_inet) udp);
udp->udp_connected = 0;
}
/*-------------------------------------------------------------------------*\
* Creates a socket structure and initializes it. A socket object is
* left in the Lua stack.
* Returns
* pointer to allocated structure
\*-------------------------------------------------------------------------*/
p_udp udp_push(lua_State *L)
{
p_udp udp = (p_udp) lua_newuserdata(L, sizeof(t_udp));
priv_setclass(L, UDP_CLASS);
udp_construct(L, udp);
return udp;
/* create classes */
aux_newclass(L, "udp{connected}", udp);
aux_newclass(L, "udp{unconnected}", udp);
/* create class groups */
aux_add2group(L, "udp{connected}", "udp{any}");
aux_add2group(L, "udp{unconnected}", "udp{any}");
/* define library functions */
luaL_openlib(L, LUASOCKET_LIBNAME, func, 0);
lua_pop(L, 1);
}
/*=========================================================================*\
* Socket table constructors
* Lua methods
\*=========================================================================*/
/*-------------------------------------------------------------------------*\
* Creates a udp socket object and returns it to the Lua script.
* Lua Input: [options]
* options: socket options table
* Lua Returns
* On success: udp socket
* On error: nil, followed by an error message
* Send data through connected udp socket
\*-------------------------------------------------------------------------*/
static int udp_global_udp(lua_State *L)
static int udp_meth_send(lua_State *L)
{
int oldtop = lua_gettop(L);
p_udp udp = udp_push(L);
cchar *err = inet_trysocket((p_inet) udp, SOCK_DGRAM);
if (err) {
lua_pushnil(L);
lua_pushstring(L, err);
return 2;
}
if (oldtop < 1) return 1;
err = compat_trysetoptions(L, udp->fd);
if (err) {
lua_pushnil(L);
lua_pushstring(L, err);
return 2;
}
return 1;
}
/*=========================================================================*\
* Socket table methods
\*=========================================================================*/
/*-------------------------------------------------------------------------*\
* Receives data from a UDP socket
* Lua Input: sock [, wanted]
* sock: client socket created by the connect function
* wanted: the number of bytes expected (default: LUASOCKET_UDPBUFFERSIZE)
* Lua Returns
* On success: datagram received
* On error: nil, followed by an error message
\*-------------------------------------------------------------------------*/
static int udp_lua_receive(lua_State *L)
{
p_udp udp = (p_udp) lua_touserdata(L, 1);
char buffer[UDP_DATAGRAMSIZE];
size_t got, wanted = (size_t) luaL_optnumber(L, 2, sizeof(buffer));
p_udp udp = (p_udp) aux_checkclass(L, "udp{connected}", 1);
p_tm tm = &udp->tm;
size_t count, sent = 0;
int err;
p_tm tm = &udp->base_tm;
wanted = MIN(wanted, sizeof(buffer));
const char *data = luaL_checklstring(L, 2, &count);
tm_markstart(tm);
err = compat_recv(udp->fd, buffer, wanted, &got, tm_getremaining(tm));
if (err == PRIV_CLOSED) err = PRIV_REFUSED;
if (err != PRIV_DONE) lua_pushnil(L);
else lua_pushlstring(L, buffer, got);
priv_pusherror(L, err);
err = sock_send(&udp->sock, data, count, &sent, tm_get(tm));
if (err == IO_DONE) lua_pushnumber(L, sent);
else lua_pushnil(L);
error_push(L, err);
return 2;
}
/*-------------------------------------------------------------------------*\
* Receives a datagram from a UDP socket
* Lua Input: sock [, wanted]
* sock: client socket created by the connect function
* wanted: the number of bytes expected (default: LUASOCKET_UDPBUFFERSIZE)
* Lua Returns
* On success: datagram received, ip and port of sender
* On error: nil, followed by an error message
* Send data through unconnected udp socket
\*-------------------------------------------------------------------------*/
static int udp_lua_receivefrom(lua_State *L)
static int udp_meth_sendto(lua_State *L)
{
p_udp udp = (p_udp) lua_touserdata(L, 1);
p_tm tm = &udp->base_tm;
struct sockaddr_in peer;
size_t peer_len = sizeof(peer);
char buffer[UDP_DATAGRAMSIZE];
size_t wanted = (size_t) luaL_optnumber(L, 2, sizeof(buffer));
size_t got;
p_udp udp = (p_udp) aux_checkclass(L, "udp{unconnected}", 1);
size_t count, sent = 0;
const char *data = luaL_checklstring(L, 2, &count);
const char *ip = luaL_checkstring(L, 3);
ushort port = (ushort) luaL_checknumber(L, 4);
p_tm tm = &udp->tm;
struct sockaddr_in addr;
int err;
if (udp->udp_connected) luaL_error(L, "receivefrom on connected socket");
memset(&addr, 0, sizeof(addr));
if (!inet_aton(ip, &addr.sin_addr))
luaL_argerror(L, 3, "invalid ip address");
addr.sin_family = AF_INET;
addr.sin_port = htons(port);
tm_markstart(tm);
wanted = MIN(wanted, sizeof(buffer));
err = compat_recvfrom(udp->fd, buffer, wanted, &got, tm_getremaining(tm),
(SA *) &peer, &peer_len);
if (err == PRIV_CLOSED) err = PRIV_REFUSED;
if (err == PRIV_DONE) {
err = sock_sendto(&udp->sock, data, count, &sent,
(SA *) &addr, sizeof(addr), tm_get(tm));
if (err == IO_DONE) lua_pushnumber(L, sent);
else lua_pushnil(L);
error_push(L, err == IO_CLOSED ? IO_REFUSED : err);
return 2;
}
/*-------------------------------------------------------------------------*\
* Receives data from a UDP socket
\*-------------------------------------------------------------------------*/
static int udp_meth_receive(lua_State *L)
{
p_udp udp = (p_udp) aux_checkgroup(L, "udp{any}", 1);
char buffer[UDP_DATAGRAMSIZE];
size_t got, count = (size_t) luaL_optnumber(L, 2, sizeof(buffer));
int err;
p_tm tm = &udp->tm;
count = MIN(count, sizeof(buffer));
tm_markstart(tm);
err = sock_recv(&udp->sock, buffer, count, &got, tm_get(tm));
if (err == IO_DONE) lua_pushlstring(L, buffer, got);
else lua_pushnil(L);
error_push(L, err);
return 2;
}
/*-------------------------------------------------------------------------*\
* Receives data and sender from a UDP socket
\*-------------------------------------------------------------------------*/
static int udp_meth_receivefrom(lua_State *L)
{
p_udp udp = (p_udp) aux_checkclass(L, "udp{unconnected}", 1);
struct sockaddr_in addr;
size_t addr_len = sizeof(addr);
char buffer[UDP_DATAGRAMSIZE];
size_t got, count = (size_t) luaL_optnumber(L, 2, sizeof(buffer));
int err;
p_tm tm = &udp->tm;
tm_markstart(tm);
count = MIN(count, sizeof(buffer));
err = sock_recvfrom(&udp->sock, buffer, count, &got,
(SA *) &addr, &addr_len, tm_get(tm));
if (err == IO_DONE) {
lua_pushlstring(L, buffer, got);
lua_pushstring(L, inet_ntoa(peer.sin_addr));
lua_pushnumber(L, ntohs(peer.sin_port));
lua_pushstring(L, inet_ntoa(addr.sin_addr));
lua_pushnumber(L, ntohs(addr.sin_port));
return 3;
} else {
lua_pushnil(L);
priv_pusherror(L, err);
error_push(L, err);
return 2;
}
}
/*-------------------------------------------------------------------------*\
* Send data through a connected UDP socket
* Lua Input: sock, data
* sock: udp socket
* data: data to be sent
* Lua Returns
* On success: nil, followed by the total number of bytes sent
* On error: error message
* Just call inet methods
\*-------------------------------------------------------------------------*/
static int udp_lua_send(lua_State *L)
static int udp_meth_getpeername(lua_State *L)
{
p_udp udp = (p_udp) lua_touserdata(L, 1);
p_tm tm = &udp->base_tm;
size_t wanted, sent = 0;
int err;
cchar *data = luaL_checklstring(L, 2, &wanted);
if (!udp->udp_connected) luaL_error(L, "send on unconnected socket");
tm_markstart(tm);
err = compat_send(udp->fd, data, wanted, &sent, tm_getremaining(tm));
priv_pusherror(L, err == PRIV_CLOSED ? PRIV_REFUSED : err);
lua_pushnumber(L, sent);
return 2;
p_udp udp = (p_udp) aux_checkclass(L, "udp{connected}", 1);
return inet_meth_getpeername(L, &udp->sock);
}
static int udp_meth_getsockname(lua_State *L)
{
p_udp udp = (p_udp) aux_checkgroup(L, "udp{any}", 1);
return inet_meth_getsockname(L, &udp->sock);
}
/*-------------------------------------------------------------------------*\
* Send data through a unconnected UDP socket
* Lua Input: sock, data, ip, port
* sock: udp socket
* data: data to be sent
* ip: ip address of target
* port: port in target
* Lua Returns
* On success: nil, followed by the total number of bytes sent
* On error: error message
* Just call tm methods
\*-------------------------------------------------------------------------*/
static int udp_lua_sendto(lua_State *L)
static int udp_meth_timeout(lua_State *L)
{
p_udp udp = (p_udp) lua_touserdata(L, 1);
size_t wanted, sent = 0;
cchar *data = luaL_checklstring(L, 2, &wanted);
cchar *ip = luaL_checkstring(L, 3);
ushort port = (ushort) luaL_checknumber(L, 4);
p_tm tm = &udp->base_tm;
struct sockaddr_in peer;
int err;
if (udp->udp_connected) luaL_error(L, "sendto on connected socket");
memset(&peer, 0, sizeof(peer));
if (!inet_aton(ip, &peer.sin_addr)) luaL_error(L, "invalid ip address");
peer.sin_family = AF_INET;
peer.sin_port = htons(port);
tm_markstart(tm);
err = compat_sendto(udp->fd, data, wanted, &sent, tm_getremaining(tm),
(SA *) &peer, sizeof(peer));
priv_pusherror(L, err == PRIV_CLOSED ? PRIV_REFUSED : err);
lua_pushnumber(L, sent);
return 2;
p_udp udp = (p_udp) aux_checkgroup(L, "udp{any}", 1);
return tm_meth_timeout(L, &udp->tm);
}
/*-------------------------------------------------------------------------*\
* Associates a local address to an UDP socket
* Lua Input: address, port
* address: host name or ip address to bind to
* port: port to bind to
* Lua Returns
* On success: nil
* On error: error message
* Turns a master udp object into a client object.
\*-------------------------------------------------------------------------*/
static int udp_lua_setsockname(lua_State * L)
static int udp_meth_setpeername(lua_State *L)
{
p_udp udp = (p_udp) lua_touserdata(L, 1);
cchar *address = luaL_checkstring(L, 2);
ushort port = (ushort) luaL_checknumber(L, 3);
cchar *err = inet_trybind((p_inet) udp, address, port);
if (err) lua_pushstring(L, err);
else lua_pushnil(L);
return 1;
}
/*-------------------------------------------------------------------------*\
* Sets a peer for a UDP socket
* Lua Input: address, port
* address: remote host name
* port: remote host port
* Lua Returns
* On success: nil
* On error: error message
\*-------------------------------------------------------------------------*/
static int udp_lua_setpeername(lua_State *L)
{
p_udp udp = (p_udp) lua_touserdata(L, 1);
cchar *address = luaL_checkstring(L, 2);
ushort port = (ushort) luaL_checknumber(L, 3);
cchar *err = inet_tryconnect((p_inet) udp, address, port);
if (!err) {
udp->udp_connected = 1;
p_udp udp = (p_udp) aux_checkclass(L, "udp{unconnected}", 1);
const char *address = luaL_checkstring(L, 2);
int connecting = strcmp(address, "*");
unsigned short port = connecting ?
(ushort) luaL_checknumber(L, 3) : (ushort) luaL_optnumber(L, 3, 0);
const char *err = inet_tryconnect(&udp->sock, address, port);
if (err) {
lua_pushnil(L);
} else lua_pushstring(L, err);
lua_pushstring(L, err);
return 2;
}
/* change class to connected or unconnected depending on address */
if (connecting) aux_setclass(L, "udp{connected}", 1);
else aux_setclass(L, "udp{unconnected}", 1);
lua_pushnumber(L, 1);
return 1;
}
/*-------------------------------------------------------------------------*\
* Closes socket used by object
\*-------------------------------------------------------------------------*/
static int udp_meth_close(lua_State *L)
{
p_udp udp = (p_udp) aux_checkgroup(L, "udp{any}", 1);
sock_destroy(&udp->sock);
return 0;
}
/*-------------------------------------------------------------------------*\
* Turns a master object into a server object
\*-------------------------------------------------------------------------*/
static int udp_meth_setsockname(lua_State *L)
{
p_udp udp = (p_udp) aux_checkclass(L, "udp{master}", 1);
const char *address = luaL_checkstring(L, 2);
unsigned short port = (ushort) luaL_checknumber(L, 3);
const char *err = inet_trybind(&udp->sock, address, port, -1);
if (err) {
lua_pushnil(L);
lua_pushstring(L, err);
return 2;
}
lua_pushnumber(L, 1);
return 1;
}
/*=========================================================================*\
* Library functions
\*=========================================================================*/
/*-------------------------------------------------------------------------*\
* Creates a master udp object
\*-------------------------------------------------------------------------*/
int udp_global_create(lua_State *L)
{
/* allocate udp object */
p_udp udp = (p_udp) lua_newuserdata(L, sizeof(t_udp));
/* set its type as master object */
aux_setclass(L, "udp{unconnected}", -1);
/* try to allocate a system socket */
const char *err = inet_trycreate(&udp->sock, SOCK_DGRAM);
if (err) {
/* get rid of object on stack and push error */
lua_pop(L, 1);
lua_pushnil(L);
lua_pushstring(L, err);
return 2;
}
/* initialize timeout management */
tm_init(&udp->tm, -1, -1);
return 1;
}

View File

@ -1,30 +1,19 @@
/*=========================================================================*\
* UDP class: inherits from Socked and Internet domain classes and provides
* all the functionality for UDP objects.
*
* RCS ID: $Id$
\*=========================================================================*/
#ifndef UDP_H_
#define UDP_H_
#ifndef UDP_H
#define UDP_H
#include "lsinet.h"
#include <lua.h>
#define UDP_CLASS "luasocket(UDP socket)"
#include "tm.h"
#include "sock.h"
#define UDP_DATAGRAMSIZE 576
#define UDP_FIELDS \
INET_FIELDS; \
int udp_connected
typedef struct t_udp_tag {
UDP_FIELDS;
typedef struct t_udp_ {
t_sock sock;
t_tm tm;
} t_udp;
typedef t_udp *p_udp;
void udp_inherit(lua_State *L, cchar *lsclass);
void udp_construct(lua_State *L, p_udp udp);
void udp_open(lua_State *L);
p_udp udp_push(lua_State *L);
#endif

View File

@ -1,5 +1,5 @@
/*=========================================================================*\
* Network compatibilization module: Unix version
* Socket compatibilization module for Unix
*
* RCS ID: $Id$
\*=========================================================================*/
@ -7,20 +7,20 @@
#include <lauxlib.h>
#include <string.h>
#include "lscompat.h"
#include "sock.h"
/*=========================================================================*\
* Internal function prototypes
\*=========================================================================*/
static cchar *try_setoption(lua_State *L, COMPAT_FD sock);
static cchar *try_setbooloption(lua_State *L, COMPAT_FD sock, int name);
static const char *try_setoption(lua_State *L, p_sock ps);
static const char *try_setbooloption(lua_State *L, p_sock ps, int name);
/*=========================================================================*\
* Exported functions.
\*=========================================================================*/
int compat_open(lua_State *L)
int sock_open(lua_State *L)
{
/* Instals a handler to ignore sigpipe. */
/* instals a handler to ignore sigpipe. */
struct sigaction new;
memset(&new, 0, sizeof(new));
new.sa_handler = SIG_IGN;
@ -28,143 +28,178 @@ int compat_open(lua_State *L)
return 1;
}
COMPAT_FD compat_accept(COMPAT_FD s, struct sockaddr *addr,
size_t *len, int deadline)
void sock_destroy(p_sock ps)
{
struct timeval tv;
fd_set fds;
tv.tv_sec = deadline / 1000;
tv.tv_usec = (deadline % 1000) * 1000;
FD_ZERO(&fds);
FD_SET(s, &fds);
select(s+1, &fds, NULL, NULL, deadline >= 0 ? &tv : NULL);
return accept(s, addr, len);
close(*ps);
}
int compat_send(COMPAT_FD c, cchar *data, size_t count, size_t *sent,
int deadline)
const char *sock_create(p_sock ps, int domain, int type, int protocol)
{
t_sock sock = socket(domain, type, protocol);
if (sock == SOCK_INVALID) return sock_createstrerror();
*ps = sock;
sock_setnonblocking(ps);
sock_setreuseaddr(ps);
return NULL;
}
const char *sock_connect(p_sock ps, SA *addr, size_t addr_len)
{
if (connect(*ps, addr, addr_len) < 0) return sock_connectstrerror();
else return NULL;
}
const char *sock_bind(p_sock ps, SA *addr, size_t addr_len)
{
if (bind(*ps, addr, addr_len) < 0) return sock_bindstrerror();
else return NULL;
}
void sock_listen(p_sock ps, int backlog)
{
listen(*ps, backlog);
}
void sock_accept(p_sock ps, p_sock pa, SA *addr, size_t *addr_len, int timeout)
{
t_sock sock = *ps;
struct timeval tv;
fd_set fds;
tv.tv_sec = timeout / 1000;
tv.tv_usec = (timeout % 1000) * 1000;
FD_ZERO(&fds);
FD_SET(sock, &fds);
select(sock+1, &fds, NULL, NULL, timeout >= 0 ? &tv : NULL);
*pa = accept(sock, addr, addr_len);
}
int sock_send(p_sock ps, const char *data, size_t count, size_t *sent,
int timeout)
{
t_sock sock = *ps;
struct timeval tv;
fd_set fds;
ssize_t put = 0;
int err;
int ret;
tv.tv_sec = deadline / 1000;
tv.tv_usec = (deadline % 1000) * 1000;
tv.tv_sec = timeout / 1000;
tv.tv_usec = (timeout % 1000) * 1000;
FD_ZERO(&fds);
FD_SET(c, &fds);
ret = select(c+1, NULL, &fds, NULL, deadline >= 0 ? &tv : NULL);
FD_SET(sock, &fds);
ret = select(sock+1, NULL, &fds, NULL, timeout >= 0 ? &tv : NULL);
if (ret > 0) {
put = write(c, data, count);
put = write(sock, data, count);
if (put <= 0) {
err = PRIV_CLOSED;
err = IO_CLOSED;
#ifdef __CYGWIN__
/* this is for CYGWIN, which is like Unix but has Win32 bugs */
if (errno == EWOULDBLOCK) err = PRIV_DONE;
if (errno == EWOULDBLOCK) err = IO_DONE;
#endif
*sent = 0;
} else {
*sent = put;
err = PRIV_DONE;
err = IO_DONE;
}
return err;
} else {
*sent = 0;
return PRIV_TIMEOUT;
return IO_TIMEOUT;
}
}
int compat_sendto(COMPAT_FD c, cchar *data, size_t count, size_t *sent,
int deadline, SA *addr, size_t len)
int sock_sendto(p_sock ps, const char *data, size_t count, size_t *sent,
SA *addr, size_t addr_len, int timeout)
{
t_sock sock = *ps;
struct timeval tv;
fd_set fds;
ssize_t put = 0;
int err;
int ret;
tv.tv_sec = deadline / 1000;
tv.tv_usec = (deadline % 1000) * 1000;
tv.tv_sec = timeout / 1000;
tv.tv_usec = (timeout % 1000) * 1000;
FD_ZERO(&fds);
FD_SET(c, &fds);
ret = select(c+1, NULL, &fds, NULL, deadline >= 0 ? &tv : NULL);
FD_SET(sock, &fds);
ret = select(sock+1, NULL, &fds, NULL, timeout >= 0 ? &tv : NULL);
if (ret > 0) {
put = sendto(c, data, count, 0, addr, len);
put = sendto(sock, data, count, 0, addr, addr_len);
if (put <= 0) {
err = PRIV_CLOSED;
err = IO_CLOSED;
#ifdef __CYGWIN__
/* this is for CYGWIN, which is like Unix but has Win32 bugs */
if (sent < 0 && errno == EWOULDBLOCK) err = PRIV_DONE;
if (sent < 0 && errno == EWOULDBLOCK) err = IO_DONE;
#endif
*sent = 0;
} else {
*sent = put;
err = PRIV_DONE;
err = IO_DONE;
}
return err;
} else {
*sent = 0;
return PRIV_TIMEOUT;
return IO_TIMEOUT;
}
}
int compat_recv(COMPAT_FD c, char *data, size_t count, size_t *got,
int deadline)
int sock_recv(p_sock ps, char *data, size_t count, size_t *got, int timeout)
{
t_sock sock = *ps;
struct timeval tv;
fd_set fds;
int ret;
ssize_t taken = 0;
tv.tv_sec = deadline / 1000;
tv.tv_usec = (deadline % 1000) * 1000;
tv.tv_sec = timeout / 1000;
tv.tv_usec = (timeout % 1000) * 1000;
FD_ZERO(&fds);
FD_SET(c, &fds);
ret = select(c+1, &fds, NULL, NULL, deadline >= 0 ? &tv : NULL);
FD_SET(sock, &fds);
ret = select(sock+1, &fds, NULL, NULL, timeout >= 0 ? &tv : NULL);
if (ret > 0) {
taken = read(c, data, count);
taken = read(sock, data, count);
if (taken <= 0) {
*got = 0;
return PRIV_CLOSED;
return IO_CLOSED;
} else {
*got = taken;
return PRIV_DONE;
return IO_DONE;
}
} else {
*got = 0;
return PRIV_TIMEOUT;
return IO_TIMEOUT;
}
}
int compat_recvfrom(COMPAT_FD c, char *data, size_t count, size_t *got,
int deadline, SA *addr, size_t *len)
int sock_recvfrom(p_sock ps, char *data, size_t count, size_t *got,
SA *addr, size_t *addr_len, int timeout)
{
t_sock sock = *ps;
struct timeval tv;
fd_set fds;
int ret;
ssize_t taken = 0;
tv.tv_sec = deadline / 1000;
tv.tv_usec = (deadline % 1000) * 1000;
tv.tv_sec = timeout / 1000;
tv.tv_usec = (timeout % 1000) * 1000;
FD_ZERO(&fds);
FD_SET(c, &fds);
ret = select(c+1, &fds, NULL, NULL, deadline >= 0 ? &tv : NULL);
FD_SET(sock, &fds);
ret = select(sock+1, &fds, NULL, NULL, timeout >= 0 ? &tv : NULL);
if (ret > 0) {
taken = recvfrom(c, data, count, 0, addr, len);
taken = recvfrom(sock, data, count, 0, addr, addr_len);
if (taken <= 0) {
*got = 0;
return PRIV_CLOSED;
return IO_CLOSED;
} else {
*got = taken;
return PRIV_DONE;
return IO_DONE;
}
} else {
*got = 0;
return PRIV_TIMEOUT;
return IO_TIMEOUT;
}
}
/*-------------------------------------------------------------------------*\
* Returns a string describing the last host manipulation error.
\*-------------------------------------------------------------------------*/
const char *compat_hoststrerror(void)
const char *sock_hoststrerror(void)
{
switch (h_errno) {
case HOST_NOT_FOUND: return "host not found";
@ -178,7 +213,7 @@ const char *compat_hoststrerror(void)
/*-------------------------------------------------------------------------*\
* Returns a string describing the last socket manipulation error.
\*-------------------------------------------------------------------------*/
const char *compat_socketstrerror(void)
const char *sock_createstrerror(void)
{
switch (errno) {
case EACCES: return "access denied";
@ -192,7 +227,7 @@ const char *compat_socketstrerror(void)
/*-------------------------------------------------------------------------*\
* Returns a string describing the last bind command error.
\*-------------------------------------------------------------------------*/
const char *compat_bindstrerror(void)
const char *sock_bindstrerror(void)
{
switch (errno) {
case EBADF: return "invalid descriptor";
@ -209,7 +244,7 @@ const char *compat_bindstrerror(void)
/*-------------------------------------------------------------------------*\
* Returns a string describing the last connect error.
\*-------------------------------------------------------------------------*/
const char *compat_connectstrerror(void)
const char *sock_connectstrerror(void)
{
switch (errno) {
case EBADF: return "invalid descriptor";
@ -229,40 +264,30 @@ const char *compat_connectstrerror(void)
* Input
* sock: socket descriptor
\*-------------------------------------------------------------------------*/
void compat_setreuseaddr(COMPAT_FD sock)
void sock_setreuseaddr(p_sock ps)
{
int val = 1;
setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (char *)&val, sizeof(val));
}
COMPAT_FD compat_socket(int domain, int type, int protocol)
{
COMPAT_FD sock = socket(domain, type, protocol);
if (sock != COMPAT_INVALIDFD) {
compat_setnonblocking(sock);
compat_setreuseaddr(sock);
}
return sock;
setsockopt(*ps, SOL_SOCKET, SO_REUSEADDR, (char *)&val, sizeof(val));
}
/*-------------------------------------------------------------------------*\
* Put socket into blocking mode.
\*-------------------------------------------------------------------------*/
void compat_setblocking(COMPAT_FD sock)
void sock_setblocking(p_sock ps)
{
int flags = fcntl(sock, F_GETFL, 0);
int flags = fcntl(*ps, F_GETFL, 0);
flags &= (~(O_NONBLOCK));
fcntl(sock, F_SETFL, flags);
fcntl(*ps, F_SETFL, flags);
}
/*-------------------------------------------------------------------------*\
* Put socket into non-blocking mode.
\*-------------------------------------------------------------------------*/
void compat_setnonblocking(COMPAT_FD sock)
void sock_setnonblocking(p_sock ps)
{
int flags = fcntl(sock, F_GETFL, 0);
int flags = fcntl(*ps, F_GETFL, 0);
flags |= O_NONBLOCK;
fcntl(sock, F_SETFL, flags);
fcntl(*ps, F_SETFL, flags);
}
/*-------------------------------------------------------------------------*\
@ -273,54 +298,50 @@ void compat_setnonblocking(COMPAT_FD sock)
* Returns
* NULL if successfull, error message on error
\*-------------------------------------------------------------------------*/
cchar *compat_trysetoptions(lua_State *L, COMPAT_FD sock)
const char *sock_trysetoptions(lua_State *L, p_sock ps)
{
if (!lua_istable(L, 1)) luaL_argerror(L, 1, "invalid options table");
lua_pushnil(L);
while (lua_next(L, 1)) {
cchar *err = try_setoption(L, sock);
const char *err = try_setoption(L, ps);
lua_pop(L, 1);
if (err) return err;
}
return NULL;
}
/*=========================================================================*\
* Internal functions.
\*=========================================================================*/
static cchar *try_setbooloption(lua_State *L, COMPAT_FD sock, int name)
{
int bool, res;
if (!lua_isnumber(L, -1)) luaL_error(L, "invalid option value");
bool = (int) lua_tonumber(L, -1);
res = setsockopt(sock, SOL_SOCKET, name, (char *) &bool, sizeof(bool));
if (res < 0) return "error setting option";
else return NULL;
}
/*-------------------------------------------------------------------------*\
* Set socket options from a table on top of Lua stack.
* Supports SO_KEEPALIVE, SO_DONTROUTE, SO_BROADCAST, and SO_LINGER options.
* Supports SO_KEEPALIVE, SO_DONTROUTE, and SO_BROADCAST options.
* Input
* L: Lua state to use
* sock: socket descriptor
* sock: socket
* Returns
* 1 if successful, 0 otherwise
\*-------------------------------------------------------------------------*/
static cchar *try_setoption(lua_State *L, COMPAT_FD sock)
static const char *try_setoption(lua_State *L, p_sock ps)
{
static cchar *options[] = {
"SO_KEEPALIVE", "SO_DONTROUTE", "SO_BROADCAST", "SO_LINGER", NULL
static const char *options[] = {
"SO_KEEPALIVE", "SO_DONTROUTE", "SO_BROADCAST", NULL
};
cchar *option = lua_tostring(L, -2);
const char *option = lua_tostring(L, -2);
if (!lua_isstring(L, -2)) return "invalid option";
switch (luaL_findstring(option, options)) {
case 0: return try_setbooloption(L, sock, SO_KEEPALIVE);
case 1: return try_setbooloption(L, sock, SO_DONTROUTE);
case 2: return try_setbooloption(L, sock, SO_BROADCAST);
case 3: return "SO_LINGER is deprecated";
case 0: return try_setbooloption(L, ps, SO_KEEPALIVE);
case 1: return try_setbooloption(L, ps, SO_DONTROUTE);
case 2: return try_setbooloption(L, ps, SO_BROADCAST);
default: return "unsupported option";
}
}
/*=========================================================================*\
* Internal functions.
\*=========================================================================*/
static const char *try_setbooloption(lua_State *L, p_sock ps, int name)
{
int bool, res;
if (!lua_isnumber(L, -1)) luaL_error(L, "invalid option value");
bool = (int) lua_tonumber(L, -1);
res = setsockopt(*ps, SOL_SOCKET, name, (char *) &bool, sizeof(bool));
if (res < 0) return "error setting option";
else return NULL;
}

View File

@ -1,10 +1,10 @@
/*=========================================================================*\
* Network compatibilization module: Unix version
* Socket compatibilization module for Unix
*
* RCS ID: $Id$
\*=========================================================================*/
#ifndef UNIX_H_
#define UNIX_H_
#ifndef UNIX_H
#define UNIX_H
/*=========================================================================*\
* BSD include files
@ -31,13 +31,9 @@
#include <netinet/in.h>
#include <arpa/inet.h>
#define COMPAT_FD int
#define COMPAT_INVALIDFD (-1)
typedef int t_sock;
typedef t_sock *p_sock;
#define compat_bind bind
#define compat_connect connect
#define compat_listen listen
#define compat_close close
#define compat_select select
#define SOCK_INVALID (-1)
#endif /* UNIX_H_ */
#endif /* UNIX_H */

View File

@ -1,5 +1,3 @@
dofile("noglobals.lua")
local similar = function(s1, s2)
return
string.lower(string.gsub(s1, "%s", "")) ==
@ -34,7 +32,7 @@ end
local index, err, saved, back, expected
local t = socket._time()
local t = socket.time()
index = readfile("test/index.html")
@ -112,4 +110,4 @@ back, err = socket.ftp.get("ftp://localhost/index.wrong.html;type=a")
check(err, err)
print("passed all tests")
print(string.format("done in %.2fs", socket._time() - t))
print(string.format("done in %.2fs", socket.time() - t))

View File

@ -3,9 +3,6 @@
-- needs ScriptAlias from /home/c/diego/tec/luasocket/test/cgi
-- to /luasocket-test-cgi
-- needs AllowOverride AuthConfig on /home/c/diego/tec/luasocket/test/auth
dofile("noglobals.lua")
local similar = function(s1, s2)
return string.lower(string.gsub(s1 or "", "%s", "")) ==
string.lower(string.gsub(s2 or "", "%s", ""))
@ -27,27 +24,27 @@ end
local check = function (v, e)
if v then print("ok")
else %fail(e) end
else fail(e) end
end
local check_request = function(request, expect, ignore)
local response = socket.http.request(request)
for i,v in response do
if not ignore[i] then
if v ~= expect[i] then %fail(i .. " differs!") end
if v ~= expect[i] then fail(i .. " differs!") end
end
end
for i,v in expect do
if not ignore[i] then
if v ~= response[i] then %fail(i .. " differs!") end
if v ~= response[i] then fail(i .. " differs!") end
end
end
print("ok")
end
local request, response, ignore, expect, index, prefix, cgiprefix
local host, request, response, ignore, expect, index, prefix, cgiprefix
local t = socket._time()
local t = socket.time()
host = host or "localhost"
prefix = prefix or "/luasocket"
@ -310,4 +307,4 @@ check(response and response.headers)
print("passed all tests")
print(string.format("done in %.2fs", socket._time() - t))
print(string.format("done in %.2fs", socket.time() - t))

View File

@ -11,7 +11,7 @@ local files = {
"/var/spool/mail/luasock3",
}
local t = socket._time()
local t = socket.time()
local err
dofile("mbox.lua")
@ -106,7 +106,7 @@ local insert = function(sent, message)
end
local mark = function()
local time = socket._time()
local time = socket.time()
return { time = time }
end
@ -116,11 +116,11 @@ local wait = function(sentinel, n)
while 1 do
local mbox = parse(get())
if n == table.getn(mbox) then break end
if socket._time() - sentinel.time > 50 then
if socket.time() - sentinel.time > 50 then
to = 1
break
end
socket._sleep(1)
socket.sleep(1)
io.write(".")
io.stdout:flush()
end
@ -256,4 +256,4 @@ for i = 1, table.getn(mbox) do
end
print("passed all tests")
print(string.format("done in %.2fs", socket._time() - t))
print(string.format("done in %.2fs", socket.time() - t))

View File

@ -43,7 +43,7 @@ function check_timeout(tm, sl, elapsed, err, opp, mode, alldone)
else pass("proper timeout") end
end
else
if mode == "return" then
if mode == "total" then
if elapsed > tm then
if err ~= "timeout" then fail("should have timed out")
else pass("proper timeout") end
@ -66,17 +66,17 @@ function check_timeout(tm, sl, elapsed, err, opp, mode, alldone)
end
end
if not socket.debug then
fail("Please define LUASOCKET_DEBUG and recompile LuaSocket")
end
io.write("----------------------------------------------\n",
"LuaSocket Test Procedures\n",
"----------------------------------------------\n")
if not socket._time or not socket._sleep then
fail("not compiled with _DEBUG")
end
start = socket.time()
start = socket._time()
function tcpreconnect()
function reconnect()
io.write("attempting data connection... ")
if data then data:close() end
remote [[
@ -87,109 +87,85 @@ function tcpreconnect()
if not data then fail(err)
else pass("connected!") end
end
reconnect = tcpreconnect
pass("attempting control connection...")
control, err = socket.connect(host, port)
if err then fail(err)
else pass("connected!") end
------------------------------------------------------------------------
test("bugs")
io.write("empty host connect: ")
function empty_connect()
if data then data:close() data = nil end
remote [[
if data then data:close() data = nil end
data = server:accept()
]]
data, err = socket.connect("", port)
if not data then
pass("ok")
data = socket.connect(host, port)
else fail("should not have connected!") end
end
empty_connect()
io.write("active close: ")
function active_close()
reconnect()
if socket._isclosed(data) then fail("should not be closed") end
data:close()
if not socket._isclosed(data) then fail("should be closed") end
data = nil
local udp = socket.udp()
if socket._isclosed(udp) then fail("should not be closed") end
udp:close()
if not socket._isclosed(udp) then fail("should be closed") end
pass("ok")
end
active_close()
------------------------------------------------------------------------
test("method registration")
function test_methods(sock, methods)
for _, v in methods do
if type(sock[v]) ~= "function" then
fail(type(sock) .. " method " .. v .. "not registered")
fail(sock.class .. " method '" .. v .. "' not registered")
end
end
pass(type(sock) .. " methods are ok")
pass(sock.class .. " methods are ok")
end
test_methods(control, {
"close",
"timeout",
test_methods(socket.tcp(), {
"connect",
"send",
"receive",
"bind",
"accept",
"setpeername",
"setsockname",
"getpeername",
"getsockname"
"getsockname",
"timeout",
"close",
})
if udpsocket then
test_methods(socket.udp(), {
"close",
"timeout",
"send",
"sendto",
"receive",
"receivefrom",
"getpeername",
"getsockname",
"setsockname",
"setpeername"
})
end
test_methods(socket.bind("*", 0), {
"close",
test_methods(socket.udp(), {
"getpeername",
"getsockname",
"setsockname",
"setpeername",
"send",
"sendto",
"receive",
"receivefrom",
"timeout",
"accept"
"close",
})
------------------------------------------------------------------------
test("select function")
function test_selectbugs()
local r, s, e = socket.select(nil, nil, 0.1)
assert(type(r) == "table" and type(s) == "table" and e == "timeout")
pass("both nil: ok")
local udp = socket.udp()
udp:close()
r, s, e = socket.select({ udp }, { udp }, 0.1)
assert(type(r) == "table" and type(s) == "table" and e == "timeout")
pass("closed sockets: ok")
e = pcall(socket.select, "wrong", 1, 0.1)
assert(e == false)
e = pcall(socket.select, {}, 1, 0.1)
assert(e == false)
pass("invalid input: ok")
test("mixed patterns")
function test_mixed(len)
reconnect()
local inter = math.ceil(len/4)
local p1 = "unix " .. string.rep("x", inter) .. "line\n"
local p2 = "dos " .. string.rep("y", inter) .. "line\r\n"
local p3 = "raw " .. string.rep("z", inter) .. "bytes"
local p4 = "end" .. string.rep("w", inter) .. "bytes"
local bp1, bp2, bp3, bp4
pass(len .. " byte(s) patterns")
remote (string.format("str = data:receive(%d)",
string.len(p1)+string.len(p2)+string.len(p3)+string.len(p4)))
sent, err = data:send(p1, p2, p3, p4)
if err then fail(err) end
remote "data:send(str); data:close()"
bp1, bp2, bp3, bp4, err = data:receive("*l", "*l", string.len(p3), "*a")
if err then fail(err) end
if bp1.."\n" == p1 and bp2.."\r\n" == p2 and bp3 == p3 and bp4 == p4 then
pass("patterns match")
else fail("patterns don't match") end
end
test_selectbugs()
test_mixed(1)
test_mixed(17)
test_mixed(200)
test_mixed(4091)
test_mixed(80199)
test_mixed(4091)
test_mixed(200)
test_mixed(17)
test_mixed(1)
------------------------------------------------------------------------
test("character line")
@ -202,7 +178,7 @@ function test_asciiline(len)
str = str .. str10
pass(len .. " byte(s) line")
remote "str = data:receive()"
err = data:send(str, "\n")
sent, err = data:send(str, "\n")
if err then fail(err) end
remote "data:send(str, '\\n')"
back, err = data:receive()
@ -230,7 +206,7 @@ function test_rawline(len)
str = str .. str10
pass(len .. " byte(s) line")
remote "str = data:receive()"
err = data:send(str, "\n")
sent, err = data:send(str, "\n")
if err then fail(err) end
remote "data:send(str, '\\n')"
back, err = data:receive()
@ -262,9 +238,9 @@ function test_raw(len)
s2 = string.rep("y", len-half)
pass(len .. " byte(s) block")
remote (string.format("str = data:receive(%d)", len))
err = data:send(s1)
sent, err = data:send(s1)
if err then fail(err) end
err = data:send(s2)
sent, err = data:send(s2)
if err then fail(err) end
remote "data:send(str)"
back, err = data:receive(len)
@ -304,39 +280,139 @@ test_raw(17)
test_raw(1)
------------------------------------------------------------------------
test("mixed patterns")
reconnect()
test("total timeout on receive")
function test_totaltimeoutreceive(len, tm, sl)
local str, err, total
reconnect()
pass("%d bytes, %ds total timeout, %ds pause", len, tm, sl)
remote (string.format ([[
data:timeout(%d)
str = string.rep('a', %d)
data:send(str)
print('server: sleeping for %ds')
socket.sleep(%d)
print('server: woke up')
data:send(str)
]], 2*tm, len, sl, sl))
data:timeout(tm, "total")
str, err, elapsed = data:receive(2*len)
check_timeout(tm, sl, elapsed, err, "receive", "total",
string.len(str) == 2*len)
end
test_totaltimeoutreceive(800091, 1, 3)
test_totaltimeoutreceive(800091, 2, 3)
test_totaltimeoutreceive(800091, 3, 2)
test_totaltimeoutreceive(800091, 3, 1)
function test_mixed(len)
local inter = math.floor(len/3)
local p1 = "unix " .. string.rep("x", inter) .. "line\n"
local p2 = "dos " .. string.rep("y", inter) .. "line\r\n"
local p3 = "raw " .. string.rep("z", inter) .. "bytes"
local bp1, bp2, bp3
pass(len .. " byte(s) patterns")
remote (string.format("str = data:receive(%d)",
string.len(p1)+string.len(p2)+string.len(p3)))
err = data:send(p1, p2, p3)
if err then fail(err) end
remote "data:send(str)"
bp1, bp2, bp3, err = data:receive("*lu", "*l", string.len(p3))
if err then fail(err) end
if bp1.."\n" == p1 and bp2.."\r\n" == p2 and bp3 == p3 then
pass("patterns match")
else fail("patterns don't match") end
------------------------------------------------------------------------
test("total timeout on send")
function test_totaltimeoutsend(len, tm, sl)
local str, err, total
reconnect()
pass("%d bytes, %ds total timeout, %ds pause", len, tm, sl)
remote (string.format ([[
data:timeout(%d)
str = data:receive(%d)
print('server: sleeping for %ds')
socket.sleep(%d)
print('server: woke up')
str = data:receive(%d)
]], 2*tm, len, sl, sl, len))
data:timeout(tm, "total")
str = string.rep("a", 2*len)
total, err, elapsed = data:send(str)
check_timeout(tm, sl, elapsed, err, "send", "total",
total == 2*len)
end
test_totaltimeoutsend(800091, 1, 3)
test_totaltimeoutsend(800091, 2, 3)
test_totaltimeoutsend(800091, 3, 2)
test_totaltimeoutsend(800091, 3, 1)
------------------------------------------------------------------------
test("blocking timeout on receive")
function test_blockingtimeoutreceive(len, tm, sl)
local str, err, total
reconnect()
pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl)
remote (string.format ([[
data:timeout(%d)
str = string.rep('a', %d)
data:send(str)
print('server: sleeping for %ds')
socket.sleep(%d)
print('server: woke up')
data:send(str)
]], 2*tm, len, sl, sl))
data:timeout(tm)
str, err, elapsed = data:receive(2*len)
check_timeout(tm, sl, elapsed, err, "receive", "blocking",
string.len(str) == 2*len)
end
test_blockingtimeoutreceive(800091, 1, 3)
test_blockingtimeoutreceive(800091, 2, 3)
test_blockingtimeoutreceive(800091, 3, 2)
test_blockingtimeoutreceive(800091, 3, 1)
------------------------------------------------------------------------
test("blocking timeout on send")
function test_blockingtimeoutsend(len, tm, sl)
local str, err, total
reconnect()
pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl)
remote (string.format ([[
data:timeout(%d)
str = data:receive(%d)
print('server: sleeping for %ds')
socket.sleep(%d)
print('server: woke up')
str = data:receive(%d)
]], 2*tm, len, sl, sl, len))
data:timeout(tm)
str = string.rep("a", 2*len)
total, err, elapsed = data:send(str)
check_timeout(tm, sl, elapsed, err, "send", "blocking",
total == 2*len)
end
test_blockingtimeoutsend(800091, 1, 3)
test_blockingtimeoutsend(800091, 2, 3)
test_blockingtimeoutsend(800091, 3, 2)
test_blockingtimeoutsend(800091, 3, 1)
------------------------------------------------------------------------
test("bugs")
io.write("empty host connect: ")
function empty_connect()
if data then data:close() data = nil end
remote [[
if data then data:close() data = nil end
data = server:accept()
]]
data, err = socket.connect("", port)
if not data then
pass("ok")
data = socket.connect(host, port)
else fail("should not have connected!") end
end
test_mixed(1)
test_mixed(17)
test_mixed(200)
test_mixed(4091)
test_mixed(80199)
test_mixed(800000)
test_mixed(80199)
test_mixed(4091)
test_mixed(200)
test_mixed(17)
test_mixed(1)
empty_connect()
-- io.write("active close: ")
function active_close()
reconnect()
if socket._isclosed(data) then fail("should not be closed") end
data:close()
if not socket._isclosed(data) then fail("should be closed") end
data = nil
local udp = socket.udp()
if socket._isclosed(udp) then fail("should not be closed") end
udp:close()
if not socket._isclosed(udp) then fail("should be closed") end
pass("ok")
end
-- active_close()
------------------------------------------------------------------------
test("closed connection detection")
@ -363,7 +439,7 @@ function test_closed()
data:close()
data = nil
]]
err, total = data:send(string.rep("ugauga", 100000))
total, err = data:send(string.rep("ugauga", 100000))
if not err then
pass("failed: output buffer is at least %d bytes long!", total)
elseif err ~= "closed" then
@ -376,106 +452,26 @@ end
test_closed()
------------------------------------------------------------------------
test("return timeout on receive")
function test_blockingtimeoutreceive(len, tm, sl)
local str, err, total
reconnect()
pass("%d bytes, %ds return timeout, %ds pause", len, tm, sl)
remote (string.format ([[
data:timeout(%d)
str = string.rep('a', %d)
data:send(str)
print('server: sleeping for %ds')
socket._sleep(%d)
print('server: woke up')
data:send(str)
]], 2*tm, len, sl, sl))
data:timeout(tm, "return")
str, err, elapsed = data:receive(2*len)
check_timeout(tm, sl, elapsed, err, "receive", "return",
string.len(str) == 2*len)
test("select function")
function test_selectbugs()
local r, s, e = socket.select(nil, nil, 0.1)
assert(type(r) == "table" and type(s) == "table" and e == "timeout")
pass("both nil: ok")
local udp = socket.udp()
udp:close()
r, s, e = socket.select({ udp }, { udp }, 0.1)
assert(type(r) == "table" and type(s) == "table" and e == "timeout")
pass("closed sockets: ok")
e = pcall(socket.select, "wrong", 1, 0.1)
assert(e == false)
e = pcall(socket.select, {}, 1, 0.1)
assert(e == false)
pass("invalid input: ok")
end
test_blockingtimeoutreceive(800091, 1, 3)
test_blockingtimeoutreceive(800091, 2, 3)
test_blockingtimeoutreceive(800091, 3, 2)
test_blockingtimeoutreceive(800091, 3, 1)
------------------------------------------------------------------------
test("return timeout on send")
function test_returntimeoutsend(len, tm, sl)
local str, err, total
reconnect()
pass("%d bytes, %ds return timeout, %ds pause", len, tm, sl)
remote (string.format ([[
data:timeout(%d)
str = data:receive(%d)
print('server: sleeping for %ds')
socket._sleep(%d)
print('server: woke up')
str = data:receive(%d)
]], 2*tm, len, sl, sl, len))
data:timeout(tm, "return")
str = string.rep("a", 2*len)
err, total, elapsed = data:send(str)
check_timeout(tm, sl, elapsed, err, "send", "return",
total == 2*len)
end
test_returntimeoutsend(800091, 1, 3)
test_returntimeoutsend(800091, 2, 3)
test_returntimeoutsend(800091, 3, 2)
test_returntimeoutsend(800091, 3, 1)
-- test_selectbugs()
------------------------------------------------------------------------
test("blocking timeout on receive")
function test_blockingtimeoutreceive(len, tm, sl)
local str, err, total
reconnect()
pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl)
remote (string.format ([[
data:timeout(%d)
str = string.rep('a', %d)
data:send(str)
print('server: sleeping for %ds')
socket._sleep(%d)
print('server: woke up')
data:send(str)
]], 2*tm, len, sl, sl))
data:timeout(tm)
str, err, elapsed = data:receive(2*len)
check_timeout(tm, sl, elapsed, err, "receive", "blocking",
string.len(str) == 2*len)
end
test_blockingtimeoutreceive(800091, 1, 3)
test_blockingtimeoutreceive(800091, 2, 3)
test_blockingtimeoutreceive(800091, 3, 2)
test_blockingtimeoutreceive(800091, 3, 1)
------------------------------------------------------------------------
test("blocking timeout on send")
function test_blockingtimeoutsend(len, tm, sl)
local str, err, total
reconnect()
pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl)
remote (string.format ([[
data:timeout(%d)
str = data:receive(%d)
print('server: sleeping for %ds')
socket._sleep(%d)
print('server: woke up')
str = data:receive(%d)
]], 2*tm, len, sl, sl, len))
data:timeout(tm)
str = string.rep("a", 2*len)
err, total, elapsed = data:send(str)
check_timeout(tm, sl, elapsed, err, "send", "blocking",
total == 2*len)
end
test_blockingtimeoutsend(800091, 1, 3)
test_blockingtimeoutsend(800091, 2, 3)
test_blockingtimeoutsend(800091, 3, 2)
test_blockingtimeoutsend(800091, 3, 1)
------------------------------------------------------------------------
test(string.format("done in %.2fs", socket._time() - start))
test(string.format("done in %.2fs", socket.time() - start))

View File

@ -13,12 +13,13 @@ while 1 do
print("server: closing connection...")
break
end
error = control:send("\n")
sent, error = control:send("\n")
if error then
control:close()
print("server: closing connection...")
break
end
print(command);
(loadstring(command))()
end
end

View File

@ -1,5 +1,5 @@
-- load tftpclnt.lua
dofile("tftpclnt.lua")
dofile("tftp.lua")
-- needs tftp server running on localhost, with root pointing to
-- a directory with index.html in it
@ -13,11 +13,8 @@ function readfile(file)
end
host = host or "localhost"
print("downloading")
err = tftp_get(host, 69, "index.html", "index.got")
retrieved, err = socket.tftp.get("tftp://" .. host .."/index.html")
assert(not err, err)
original = readfile("test/index.html")
retrieved = readfile("index.got")
os.remove("index.got")
assert(original == retrieved, "files differ!")
print("passed")