diff --git a/TODO b/TODO index 9e9923c..7479cfc 100644 --- a/TODO +++ b/TODO @@ -19,10 +19,16 @@ * Separar as classes em arquivos * Retorno de sendto em datagram sockets pode ser refused +unify filter and send/receive callback. new sink/source/pump idea. +get rid of aux_optlstring +wrap sink and sources with a function that performs the replacement +get rid of unpack in mime.lua check garbage collection in test*.lua pop3??? +break chain into a simpler binary chain and a complex (recursive) one. + add socket.TIMEOUT to be default timeout? manual diff --git a/doc/reference.html b/doc/reference.html index e6efb6e..6e14891 100644 --- a/doc/reference.html +++ b/doc/reference.html @@ -126,7 +126,7 @@ MIME (socket.mime)
high-level: -canonic, +normalize, chain, decode, encode, diff --git a/etc/eol.lua b/etc/eol.lua index fea5da9..6b2a8a9 100644 --- a/etc/eol.lua +++ b/etc/eol.lua @@ -1,9 +1,9 @@ marker = {['-u'] = '\10', ['-d'] = '\13\10'} arg = arg or {'-u'} marker = marker[arg[1]] or marker['-u'] -local convert = socket.mime.canonic(marker) +local convert = socket.mime.normalize(marker) while 1 do - local chunk = io.read(4096) + local chunk = io.read(1) io.write(convert(chunk)) if not chunk then break end end diff --git a/etc/qp.lua b/etc/qp.lua index 23c834a..1ca0ae2 100644 --- a/etc/qp.lua +++ b/etc/qp.lua @@ -2,10 +2,10 @@ local convert arg = arg or {} local mode = arg and arg[1] or "-et" if mode == "-et" then - local canonic = socket.mime.canonic() + local normalize = socket.mime.normalize() local qp = socket.mime.encode("quoted-printable") local wrap = socket.mime.wrap("quoted-printable") - convert = socket.mime.chain(canonic, qp, wrap) + convert = socket.mime.chain(normalize, qp, wrap) elseif mode == "-eb" then local qp = socket.mime.encode("quoted-printable", "binary") local wrap = socket.mime.wrap("quoted-printable") diff --git a/src/ltn12.lua b/src/ltn12.lua new file mode 100644 index 0000000..548588a --- /dev/null +++ b/src/ltn12.lua @@ -0,0 +1,171 @@ +-- create code namespace inside LuaSocket namespace +ltn12 = ltn12 or {} +-- make all module globals fall into mime namespace +setmetatable(ltn12, { __index = _G }) +setfenv(1, ltn12) + +-- sub namespaces +filter = {} +source = {} +sink = {} + +-- 2048 seems to be better in windows... +BLOCKSIZE = 2048 + +-- returns a high level filter that cycles a cycles a low-level filter +function filter.cycle(low, ctx, extra) + return function(chunk) + local ret + ret, ctx = low(ctx, chunk, extra) + return ret + end +end + +-- chains two filters together +local function chain2(f1, f2) + return function(chunk) + local ret = f2(f1(chunk)) + if chunk then return ret + else return ret .. f2() end + end +end + +-- chains a bunch of filters together +function filter.chain(...) + local f = arg[1] + for i = 2, table.getn(arg) do + f = chain2(f, arg[i]) + end + return f +end + +-- create an empty source +function source.empty(err) + return function() + return nil, err + end +end + +-- creates a file source +function source.file(handle, io_err) + if handle then + return function() + local chunk = handle:read(BLOCKSIZE) + if not chunk then handle:close() end + return chunk + end + else source.empty(io_err or "unable to open file") end +end + +-- turns a fancy source into a simple source +function source.simplify(src) + return function() + local chunk, err_or_new = src() + src = err_or_new or src + if not chunk then return nil, err_or_new + else return chunk end + end +end + +-- creates string source +function source.string(s) + if s then + local i = 1 + return function() + local chunk = string.sub(s, i, i+BLOCKSIZE-1) + i = i + BLOCKSIZE + if chunk ~= "" then return chunk + else return nil end + end + else source.empty() end +end + +-- creates rewindable source +function source.rewind(src) + local t = {} + src = source.simplify(src) + return function(chunk) + if not chunk then + chunk = table.remove(t) + if not chunk then return src() + else return chunk end + else + table.insert(t, chunk) + end + end +end + +-- chains a source with a filter +function source.chain(src, f) + src = source.simplify(src) + local chain = function() + local chunk, err = src() + if not chunk then return f(nil), source.empty(err) + else return f(chunk) end + end + return source.simplify(chain) +end + +-- creates a sink that stores into a table +function sink.table(t) + t = t or {} + local f = function(chunk, err) + if chunk then table.insert(t, chunk) end + return 1 + end + return f, t +end + +-- turns a fancy sink into a simple sink +function sink.simplify(snk) + return function(chunk, err) + local ret, err_or_new = snk(chunk, err) + if not ret then return nil, err_or_new end + snk = err_or_new or snk + return 1 + end +end + +-- creates a file sink +function sink.file(handle, io_err) + if handle then + return function(chunk, err) + if not chunk then + handle:close() + return nil, err + end + return handle:write(chunk) + end + else sink.null() end +end + +-- creates a sink that discards data +local function null() + return 1 +end + +function sink.null() + return null +end + +-- chains a sink with a filter +function sink.chain(f, snk) + snk = sink.simplify(snk) + return function(chunk, err) + local r, e = snk(f(chunk)) + if not r then return nil, e end + if not chunk then return snk(nil, err) end + return 1 + end +end + +-- pumps all data from a source to a sink +function pump(src, snk) + snk = sink.simplify(snk) + for chunk, src_err in source.simplify(src) do + local ret, snk_err = snk(chunk, src_err) + if not chunk or not ret then + return not src_err and not snk_err, src_err or snk_err + end + end +end diff --git a/src/luasocket.c b/src/luasocket.c index e99fcdf..47696cb 100644 --- a/src/luasocket.c +++ b/src/luasocket.c @@ -72,6 +72,7 @@ static int mod_open(lua_State *L, const luaL_reg *mod) { for (; mod->name; mod++) mod->func(L); #ifdef LUASOCKET_COMPILED +#include "ltn12.lch" #include "auxiliar.lch" #include "concat.lch" #include "url.lch" @@ -81,6 +82,7 @@ static int mod_open(lua_State *L, const luaL_reg *mod) #include "ftp.lch" #include "http.lch" #else + lua_dofile(L, "ltn12.lua"); lua_dofile(L, "auxiliar.lua"); lua_dofile(L, "concat.lua"); lua_dofile(L, "url.lua"); diff --git a/src/mime.c b/src/mime.c index 9fc4f51..ae4084d 100644 --- a/src/mime.c +++ b/src/mime.c @@ -9,8 +9,6 @@ #include #include -#include "luasocket.h" -#include "auxiliar.h" #include "mime.h" /*=========================================================================*\ @@ -83,12 +81,10 @@ static UC b64unbase[256]; \*-------------------------------------------------------------------------*/ int mime_open(lua_State *L) { - lua_pushstring(L, LUASOCKET_LIBNAME); - lua_gettable(L, LUA_GLOBALSINDEX); lua_pushstring(L, "mime"); lua_newtable(L); luaL_openlib(L, NULL, func, 0); - lua_settable(L, -3); + lua_settable(L, LUA_GLOBALSINDEX); lua_pop(L, 1); /* initialize lookup tables */ qpsetup(qpclass, qpunbase); @@ -110,7 +106,7 @@ static int mime_global_wrp(lua_State *L) { size_t size = 0; int left = (int) luaL_checknumber(L, 1); - const UC *input = (UC *) aux_optlstring(L, 2, NULL, &size); + const UC *input = (UC *) luaL_optlstring(L, 2, NULL, &size); const UC *last = input + size; int length = (int) luaL_optnumber(L, 3, 76); luaL_Buffer buffer; @@ -261,7 +257,7 @@ static int mime_global_b64(lua_State *L) luaL_buffinit(L, &buffer); while (input < last) asize = b64encode(*input++, atom, asize, &buffer); - input = (UC *) aux_optlstring(L, 2, NULL, &isize); + input = (UC *) luaL_optlstring(L, 2, NULL, &isize); if (input) { last = input + isize; while (input < last) @@ -289,7 +285,7 @@ static int mime_global_unb64(lua_State *L) luaL_buffinit(L, &buffer); while (input < last) asize = b64decode(*input++, atom, asize, &buffer); - input = (UC *) aux_optlstring(L, 2, NULL, &isize); + input = (UC *) luaL_optlstring(L, 2, NULL, &isize); if (input) { last = input + isize; while (input < last) @@ -426,14 +422,14 @@ static int mime_global_qp(lua_State *L) size_t asize = 0, isize = 0; UC atom[3]; - const UC *input = (UC *) aux_optlstring(L, 1, NULL, &isize); + const UC *input = (UC *) luaL_optlstring(L, 1, NULL, &isize); const UC *last = input + isize; const char *marker = luaL_optstring(L, 3, CRLF); luaL_Buffer buffer; luaL_buffinit(L, &buffer); while (input < last) asize = qpencode(*input++, atom, asize, marker, &buffer); - input = (UC *) aux_optlstring(L, 2, NULL, &isize); + input = (UC *) luaL_optlstring(L, 2, NULL, &isize); if (input) { last = input + isize; while (input < last) @@ -495,13 +491,13 @@ static int mime_global_unqp(lua_State *L) size_t asize = 0, isize = 0; UC atom[3]; - const UC *input = (UC *) aux_optlstring(L, 1, NULL, &isize); + const UC *input = (UC *) luaL_optlstring(L, 1, NULL, &isize); const UC *last = input + isize; luaL_Buffer buffer; luaL_buffinit(L, &buffer); while (input < last) asize = qpdecode(*input++, atom, asize, &buffer); - input = (UC *) aux_optlstring(L, 2, NULL, &isize); + input = (UC *) luaL_optlstring(L, 2, NULL, &isize); if (input) { last = input + isize; while (input < last) @@ -525,7 +521,7 @@ static int mime_global_qpwrp(lua_State *L) { size_t size = 0; int left = (int) luaL_checknumber(L, 1); - const UC *input = (UC *) aux_optlstring(L, 2, NULL, &size); + const UC *input = (UC *) luaL_optlstring(L, 2, NULL, &size); const UC *last = input + size; int length = (int) luaL_optnumber(L, 3, 76); luaL_Buffer buffer; @@ -576,54 +572,52 @@ static int mime_global_qpwrp(lua_State *L) * probably other more obscure conventions. \*-------------------------------------------------------------------------*/ #define eolcandidate(c) (c == CR || c == LF) -static size_t eolconvert(UC c, UC *input, size_t size, - const char *marker, luaL_Buffer *buffer) +static size_t eolprocess(int c, int ctx, const char *marker, + luaL_Buffer *buffer) { - input[size++] = c; - /* deal with all characters we can deal */ - if (eolcandidate(input[0])) { - if (size < 2) return size; + if (eolcandidate(ctx)) { luaL_addstring(buffer, marker); - if (eolcandidate(input[1])) { - if (input[0] == input[1]) luaL_addstring(buffer, marker); - } else luaL_putchar(buffer, input[1]); - return 0; + if (eolcandidate(c)) { + if (c == ctx) + luaL_addstring(buffer, marker); + return 0; + } else { + luaL_putchar(buffer, c); + return 0; + } } else { - luaL_putchar(buffer, input[0]); - return 0; + if (!eolcandidate(c)) { + luaL_putchar(buffer, c); + return 0; + } else + return c; } } /*-------------------------------------------------------------------------*\ * Converts a string to uniform EOL convention. -* A, B = eol(C, D, marker) -* A is the converted version of the largest prefix of C .. D that -* can be converted without doubts. -* B has the remaining bytes of C .. D, *without* convertion. +* A, n = eol(o, B, marker) +* A is the converted version of the largest prefix of B that can be +* converted unambiguously. 'o' is the context returned by the previous +* call. 'n' is the new context. \*-------------------------------------------------------------------------*/ static int mime_global_eol(lua_State *L) { - size_t asize = 0, isize = 0; - UC atom[2]; - const UC *input = (UC *) aux_optlstring(L, 1, NULL, &isize); - const UC *last = input + isize; + int ctx = luaL_checkint(L, 1); + size_t isize = 0; + const char *input = luaL_optlstring(L, 2, NULL, &isize); + const char *last = input + isize; const char *marker = luaL_optstring(L, 3, CRLF); luaL_Buffer buffer; luaL_buffinit(L, &buffer); while (input < last) - asize = eolconvert(*input++, atom, asize, marker, &buffer); - input = (UC *) aux_optlstring(L, 2, NULL, &isize); - if (input) { - last = input + isize; - while (input < last) - asize = eolconvert(*input++, atom, asize, marker, &buffer); - /* if there is something in atom, it's one character, and it - * is a candidate. so we output a new line */ - } else if (asize > 0) { - luaL_addstring(&buffer, marker); - asize = 0; + ctx = eolprocess(*input++, ctx, marker, &buffer); + /* if the last character was a candidate, we output a new line */ + if (!input) { + if (eolcandidate(ctx)) luaL_addstring(&buffer, marker); + ctx = 0; } luaL_pushresult(&buffer); - lua_pushlstring(L, (char *) atom, asize); + lua_pushnumber(L, ctx); return 2; } diff --git a/src/mime.lua b/src/mime.lua index 369567f..4df0388 100644 --- a/src/mime.lua +++ b/src/mime.lua @@ -1,11 +1,6 @@ --- make sure LuaSocket is loaded -if not LUASOCKET_LIBNAME then error('module requires LuaSocket') end --- get LuaSocket namespace -local socket = _G[LUASOCKET_LIBNAME] -if not socket then error('module requires LuaSocket') end --- create code namespace inside LuaSocket namespace -local mime = socket.mime or {} -socket.mime = mime +if not ltn12 then error('This module requires LTN12') end +-- create mime namespace +mime = mime or {} -- make all module globals fall into mime namespace setmetatable(mime, { __index = _G }) setfenv(1, mime) @@ -15,80 +10,61 @@ local et = {} local dt = {} local wt = {} --- creates a function that chooses a filter from a given table +-- creates a function that chooses a filter by name from a given table local function choose(table) - return function(filter, ...) - local f = table[filter or "nil"] - if not f then error("unknown filter (" .. tostring(filter) .. ")", 3) - else return f(unpack(arg)) end + return function(name, opt) + local f = table[name or "nil"] + if not f then error("unknown filter (" .. tostring(name) .. ")", 3) + else return f(opt) end end end -- define the encoding filters et['base64'] = function() - return socket.cicle(b64, "") + return ltn12.filter.cycle(b64, "") end et['quoted-printable'] = function(mode) - return socket.cicle(qp, "", (mode == "binary") and "=0D=0A" or "\13\10") + return ltn12.filter.cycle(qp, "", + (mode == "binary") and "=0D=0A" or "\13\10") end -- define the decoding filters dt['base64'] = function() - return socket.cicle(unb64, "") + return ltn12.filter.cycle(unb64, "") end dt['quoted-printable'] = function() - return socket.cicle(unqp, "") + return ltn12.filter.cycle(unqp, "") end -- define the line-wrap filters wt['text'] = function(length) length = length or 76 - return socket.cicle(wrp, length, length) + return ltn12.filter.cycle(wrp, length, length) end wt['base64'] = wt['text'] wt['quoted-printable'] = function() - return socket.cicle(qpwrp, 76, 76) + return ltn12.filter.cycle(qpwrp, 76, 76) end -- function that choose the encoding, decoding or wrap algorithm encode = choose(et) decode = choose(dt) --- there is a default wrap filter +-- there is different because there is a default wrap filter local cwt = choose(wt) -function wrap(...) - if type(arg[1]) ~= "string" then table.insert(arg, 1, "text") end - return cwt(unpack(arg)) -end - --- define the end-of-line translation filter -function canonic(marker) - return socket.cicle(eol, "", marker) -end - --- chains several filters together -function chain(...) - local layers = table.getn(arg) - return function (chunk) - if not chunk then - local parts = {} - for i = 1, layers do - for j = i, layers do - chunk = arg[j](chunk) - end - table.insert(parts, chunk) - chunk = nil - end - return table.concat(parts) - else - for j = 1, layers do - chunk = arg[j](chunk) - end - return chunk - end +function wrap(mode_or_length, length) + if type(mode_or_length) ~= "string" then + length = mode_or_length + mode_or_length = "text" end + return cwt(mode_or_length, length) +end + +-- define the end-of-line normalization filter +function normalize(marker) + return ltn12.filter.cycle(eol, 0, marker) end return mime diff --git a/src/wsocket.c b/src/wsocket.c index fcd2fff..2993c35 100644 --- a/src/wsocket.c +++ b/src/wsocket.c @@ -191,7 +191,7 @@ int sock_send(p_sock ps, const char *data, size_t count, size_t *sent, int timeout) { t_sock sock = *ps; - ssize_t put; + int put; int ret; /* avoid making system calls on closed sockets */ if (sock == SOCK_INVALID) return IO_CLOSED; @@ -227,7 +227,7 @@ int sock_sendto(p_sock ps, const char *data, size_t count, size_t *sent, SA *addr, socklen_t addr_len, int timeout) { t_sock sock = *ps; - ssize_t put; + int put; int ret; /* avoid making system calls on closed sockets */ if (sock == SOCK_INVALID) return IO_CLOSED; @@ -262,7 +262,7 @@ int sock_sendto(p_sock ps, const char *data, size_t count, size_t *sent, int sock_recv(p_sock ps, char *data, size_t count, size_t *got, int timeout) { t_sock sock = *ps; - ssize_t taken; + int taken; if (sock == SOCK_INVALID) return IO_CLOSED; taken = recv(sock, data, (int) count, 0); if (taken <= 0) { @@ -288,7 +288,7 @@ int sock_recvfrom(p_sock ps, char *data, size_t count, size_t *got, SA *addr, socklen_t *addr_len, int timeout) { t_sock sock = *ps; - ssize_t taken; + int taken; if (sock == SOCK_INVALID) return IO_CLOSED; taken = recvfrom(sock, data, (int) count, 0, addr, addr_len); if (taken <= 0) { diff --git a/src/wsocket.h b/src/wsocket.h index d77841e..c048c58 100644 --- a/src/wsocket.h +++ b/src/wsocket.h @@ -13,7 +13,6 @@ #include typedef int socklen_t; -typedef int ssize_t; typedef SOCKET t_sock; typedef t_sock *p_sock; diff --git a/test/ltn12test.lua b/test/ltn12test.lua new file mode 100644 index 0000000..1c1f3f2 --- /dev/null +++ b/test/ltn12test.lua @@ -0,0 +1,3 @@ +sink = ltn12.sink.file(io.open("lixo", "w")) +source = ltn12.source.file(io.open("luasocket", "r")) +ltn12.pump(source, sink) diff --git a/test/mimetest.lua b/test/mimetest.lua index 5a461fa..1a7e427 100644 --- a/test/mimetest.lua +++ b/test/mimetest.lua @@ -1,4 +1,4 @@ -dofile("noglobals.lua") +dofile("testsupport.lua") local qptest = "qptest.bin" local eqptest = "qptest.bin2" @@ -31,26 +31,13 @@ local mao = [[ assim, nem tudo o que dava exprimia grande confiança. ]] -local fail = function(s) - s = s or "failed" - assert(nil, s) -end - -local readfile = function(name) - local f = io.open(name, "r") - if not f then return nil end - local s = f:read("*a") - f:close() - return s -end - local function transform(input, output, filter) local fi, err = io.open(input, "rb") if not fi then fail(err) end local fo, err = io.open(output, "wb") if not fo then fail(err) end while 1 do - local chunk = fi:read(math.random(0, 256)) + local chunk = fi:read(math.random(0, 1024)) fo:write(filter(chunk)) if not chunk then break end end @@ -58,17 +45,10 @@ local function transform(input, output, filter) fo:close() end -local function compare(input, output) - local original = readfile(input) - local recovered = readfile(output) - if original ~= recovered then fail("recovering failed") - else print("ok") end -end - local function encode_qptest(mode) - local encode = socket.mime.encode("quoted-printable", mode) - local split = socket.mime.wrap("quoted-printable") - local chain = socket.mime.chain(encode, split) + local encode = mime.encode("quoted-printable", mode) + local split = mime.wrap("quoted-printable") + local chain = ltn12.filter.chain(encode, split) transform(qptest, eqptest, chain) end @@ -77,7 +57,7 @@ local function compare_qptest() end local function decode_qptest() - local decode = socket.mime.decode("quoted-printable") + local decode = mime.decode("quoted-printable") transform(eqptest, dqptest, decode) end @@ -151,24 +131,24 @@ local function cleanup_qptest() end local function encode_b64test() - local e1 = socket.mime.encode("base64") - local e2 = socket.mime.encode("base64") - local e3 = socket.mime.encode("base64") - local e4 = socket.mime.encode("base64") - local sp4 = socket.mime.wrap() - local sp3 = socket.mime.wrap(59) - local sp2 = socket.mime.wrap("base64", 30) - local sp1 = socket.mime.wrap(27) - local chain = socket.mime.chain(e1, sp1, e2, sp2, e3, sp3, e4, sp4) + local e1 = mime.encode("base64") + local e2 = mime.encode("base64") + local e3 = mime.encode("base64") + local e4 = mime.encode("base64") + local sp4 = mime.wrap() + local sp3 = mime.wrap(59) + local sp2 = mime.wrap("base64", 30) + local sp1 = mime.wrap(27) + local chain = ltn12.filter.chain(e1, sp1, e2, sp2, e3, sp3, e4, sp4) transform(b64test, eb64test, chain) end local function decode_b64test() - local d1 = socket.mime.decode("base64") - local d2 = socket.mime.decode("base64") - local d3 = socket.mime.decode("base64") - local d4 = socket.mime.decode("base64") - local chain = socket.mime.chain(d1, d2, d3, d4) + local d1 = mime.decode("base64") + local d2 = mime.decode("base64") + local d3 = mime.decode("base64") + local d4 = mime.decode("base64") + local chain = ltn12.filter.chain(d1, d2, d3, d4) transform(eb64test, db64test, chain) end @@ -182,11 +162,11 @@ local function compare_b64test() end local function identity_test() - local chain = socket.mime.chain( - socket.mime.encode("quoted-printable"), - socket.mime.encode("base64"), - socket.mime.decode("base64"), - socket.mime.decode("quoted-printable") + local chain = ltn12.filter.chain( + mime.encode("quoted-printable"), + mime.encode("base64"), + mime.decode("base64"), + mime.decode("quoted-printable") ) transform(b64test, eb64test, chain) compare(b64test, eb64test) @@ -195,8 +175,8 @@ end local function padcheck(original, encoded) - local e = (socket.mime.b64(original)) - local d = (socket.mime.unb64(encoded)) + local e = (mime.b64(original)) + local d = (mime.unb64(encoded)) if e ~= encoded then fail("encoding failed") end if d ~= original then fail("decoding failed") end end @@ -206,8 +186,8 @@ local function chunkcheck(original, encoded) for i = 0, len do local a = string.sub(original, 1, i) local b = string.sub(original, i+1) - local e, r = socket.mime.b64(a, b) - local f = (socket.mime.b64(r)) + local e, r = mime.b64(a, b) + local f = (mime.b64(r)) if (e .. f ~= encoded) then fail(e .. f) end end end @@ -231,6 +211,13 @@ end local t = socket.time() +identity_test() +encode_b64test() +decode_b64test() +compare_b64test() +cleanup_b64test() +padding_b64test() + create_qptest() encode_qptest() decode_qptest() @@ -240,12 +227,4 @@ decode_qptest() compare_qptest() cleanup_qptest() -encode_b64test() -decode_b64test() -compare_b64test() -cleanup_b64test() -padding_b64test() - -identity_test() - print(string.format("done in %.2fs", socket.time() - t))