diff --git a/TODO b/TODO index ab88dbd..4b475ae 100644 --- a/TODO +++ b/TODO @@ -1,3 +1,10 @@ + +comment the need of a content-length header in the post method... + +comment the callback.lua module and the new mime module. + escape and unescape are missing! + +add _tostring methods! add callback module to manual change stay to redirect in http.lua and in manual add timeout to request table diff --git a/doc/reference.css b/doc/reference.css index 4f17046..cd7de2c 100644 --- a/doc/reference.css +++ b/doc/reference.css @@ -16,7 +16,6 @@ blockquote { margin-left: 3em; } a[href] { color: #00007f; } p.name { - font-size: large; font-family: monospace; padding-top: 1em; } diff --git a/doc/reference.html b/doc/reference.html index 08fd068..99b1ea7 100644 --- a/doc/reference.html +++ b/doc/reference.html @@ -92,7 +92,7 @@ - +
    @@ -135,7 +135,7 @@ - +
      @@ -158,7 +158,7 @@ - +
        diff --git a/doc/stream.html b/doc/stream.html index 296ca2e..b88cbb5 100644 --- a/doc/stream.html +++ b/doc/stream.html @@ -36,21 +36,21 @@

        Streaming with Callbacks

        -HTTP and FTP transfers sometimes involve large amounts of information. -Sometimes an application needs to generate outgoing data in real time, -or needs to process incoming information as it is being received. To -address these problems, LuaSocket allows HTTP message bodies and FTP -file contents to be received or sent through the callback mechanism -outlined below. +HTTP, FTP, and SMTP transfers sometimes involve large amounts of +information. Sometimes an application needs to generate outgoing data +in real time, or needs to process incoming information as it is being +received. To address these problems, LuaSocket allows HTTP and SMTP message +bodies and FTP file contents to be received or sent through the +callback mechanism outlined below.

        -Instead of returning the entire contents of a FTP file or HTTP message -body as strings to the Lua application, the library allows the user to +Instead of returning the entire contents of an entity +as strings to the Lua application, the library allows the user to provide a receive callback that will be called with successive chunks of data, as the data becomes available. Conversely, the send -callbacks should be used when data needed by LuaSocket -is generated incrementally by the application. +callbacks can be used when the application wants to incrementally +provide LuaSocket with the data to be sent.

        @@ -68,21 +68,26 @@ callback receives successive chunks of downloaded data.

        Chunk contains the current chunk of data. When the transmission is over, the function is called with an -empty string (i.e. "") as the chunk. If an error occurs, the -function receives nil as chunk and an error message as -err. +empty string (i.e. "") as the chunk. +If an error occurs, the function receives nil +as chunk and an error message in err.

        -The callback can abort transmission by returning -nil as its first return value. In that case, it can also return -an error message. Any non-nil return value proceeds with the -transmission. +The callback can abort transmission by returning nil as its first +return value, and an optional error message as the +second return value. If the application wants to continue receiving +data, the function should return non-nil as it's first return +value. In this case, the function can optionally return a +new callback function, to replace itself, as the second return value. +

        + +

        +Note: The callback module provides several standard receive callbacks, including the following:

        --- The implementation of socket.callback.receive_concat
        -function Public.receive_concat(concat)
        +function receive.concat(concat)
             concat = concat or socket.concat.create()
             local callback = function(chunk, err)
                 -- if not finished, add chunk
        @@ -95,6 +100,12 @@ function Public.receive_concat(concat)
         end
         
        +

        +This function creates a new receive callback that concatenates all +received chunks into a the same concat object, which can later be +queried for its contents. +

        +

        @@ -107,45 +118,27 @@ library needs more data to be sent.

        -Each time the callback is called, it -should return the next part of the information the library is expecting, -followed by the total number of bytes to be sent. -The callback can abort -the process at any time by returning nil followed by an -optional error message. +Each time the callback is called, it should return the next chunk of data. It +can optionally return, as it's second return value, a new callback to replace +itself. The callback can abort the process at any time by returning +nil followed by an optional error message.

        -

        -Note: The need for the second return value comes from the fact that, with -the HTTP protocol for instance, the library needs to know in advance the -total number of bytes that will be sent. +Note: Below is the implementation of the callback.send.file +function. Given an open file handle, it returns a send callback that will send the contents of that file, chunk by chunk.

        --- The implementation of socket.callback.send_file
        -function Public.send_file(file)
        -    local callback
        -    -- if successfull, return the callback that reads from the file
        +function send.file(file, io_err)
        +    -- if successful, return the callback that reads from the file
             if file then
        -        -- get total size
        -        local size = file:seek("end") 
        -        -- go back to start of file
        -        file:seek("set")
        -        callback = function()
        +        return function()
                     -- send next block of data
        -            local chunk = file:read(Public.BLOCKSIZE)
        -            if not chunk then file:close() end
        -            return chunk, size
        +            return (file:read(BLOCKSIZE)) or ""
                 end
             -- else, return a callback that just aborts the transfer
        -    else
        -        callback = function() 
        -            -- just abort
        -            return nil, "unable to open file"
        -        end
        -    end
        -    return callback, file
        +    else return fail(io_err or "unable to open file") end
         end
         
        diff --git a/etc/eol.lua b/etc/eol.lua index 234cc4d..fea5da9 100644 --- a/etc/eol.lua +++ b/etc/eol.lua @@ -1,7 +1,7 @@ marker = {['-u'] = '\10', ['-d'] = '\13\10'} arg = arg or {'-u'} marker = marker[arg[1]] or marker['-u'] -local convert = socket.code.canonic(marker) +local convert = socket.mime.canonic(marker) while 1 do local chunk = io.read(4096) io.write(convert(chunk)) diff --git a/etc/get.lua b/etc/get.lua index 2d804a0..9f29a51 100644 --- a/etc/get.lua +++ b/etc/get.lua @@ -42,8 +42,8 @@ function nicesize(b) end -- returns a string with the current state of the download -function gauge(got, dt, size) - local rate = got / dt +function gauge(got, delta, size) + local rate = got / delta if size and size >= 1 then return string.format("%s received, %s/s throughput, " .. "%.0f%% done, %s remaining", @@ -55,53 +55,56 @@ function gauge(got, dt, size) return string.format("%s received, %s/s throughput, %s elapsed", nicesize(got), nicesize(rate), - nicetime(dt)) + nicetime(delta)) end end -- creates a new instance of a receive_cb that saves to disk -- kind of copied from luasocket's manual callback examples -function receive2disk(file, size) - local aux = { - start = socket.time(), - got = 0, - file = io.open(file, "wb"), - size = size - } - local receive_cb = function(chunk, err) - local dt = socket.time() - aux.start -- elapsed time since start - if not chunk or chunk == "" then - io.write("\n") - aux.file:close() - return +function stats(size) + local start = socket.time() + local got = 0 + return function(chunk) + -- elapsed time since start + local delta = socket.time() - start + if chunk then + -- total bytes received + got = got + string.len(chunk) + -- not enough time for estimate + if delta > 0.1 then + io.stderr:write("\r", gauge(got, delta, size)) + io.stderr:flush() + end + return chunk + else + -- close up + io.stderr:write("\n") + return "" end - aux.file:write(chunk) - aux.got = aux.got + string.len(chunk) -- total bytes received - if dt < 0.1 then return 1 end -- not enough time for estimate - io.write("\r", gauge(aux.got, dt, aux.size)) - return 1 end - return receive_cb end -- downloads a file using the ftp protocol function getbyftp(url, file) + local save = socket.callback.receive.file(file or io.stdout) + if file then + save = socket.callback.receive.chain(stats(gethttpsize(url)), save) + end local err = socket.ftp.get_cb { url = url, - content_cb = receive2disk(file), + content_cb = save, type = "i" } - print() if err then print(err) end end -- downloads a file using the http protocol -function getbyhttp(url, file, size) - local response = socket.http.request_cb( - {url = url}, - {body_cb = receive2disk(file, size)} - ) - print() +function getbyhttp(url, file) + local save = socket.callback.receive.file(file or io.stdout) + if file then + save = socket.callback.receive.chain(stats(gethttpsize(url)), save) + end + local response = socket.http.request_cb({url = url}, {body_cb = save}) if response.code ~= 200 then print(response.status or response.error) end end @@ -116,26 +119,22 @@ function gethttpsize(url) end end --- determines the scheme and the file name of a given url -function getschemeandname(url, name) +-- determines the scheme +function getscheme(url) -- this is an heuristic to solve a common invalid url poblem if not string.find(url, "//") then url = "//" .. url end local parsed = socket.url.parse(url, {scheme = "http"}) - if name then return parsed.scheme, name end - local segment = socket.url.parse_path(parsed.path) - name = segment[table.getn(segment)] - if segment.is_directory then name = nil end - return parsed.scheme, name + return parsed.scheme end -- gets a file either by http or ftp, saving as function get(url, name) - local scheme - scheme, name = getschemeandname(url, name) - if not name then print("unknown file name") - elseif scheme == "ftp" then getbyftp(url, name) - elseif scheme == "http" then getbyhttp(url, name, gethttpsize(url)) + local fout = name and io.open(name, "wb") + local scheme = getscheme(url) + if scheme == "ftp" then getbyftp(url, fout) + elseif scheme == "http" then getbyhttp(url, fout) else print("unknown scheme" .. scheme) end + if name then fout:close() end end -- main program diff --git a/src/http.lua b/src/http.lua index fb13d99..72bde0a 100644 --- a/src/http.lua +++ b/src/http.lua @@ -421,7 +421,7 @@ end ----------------------------------------------------------------------------- local function authorize(reqt, parsed, respt) reqt.headers["authorization"] = "Basic " .. - (socket.code.b64(parsed.user .. ":" .. parsed.password)) + (socket.mime.b64(parsed.user .. ":" .. parsed.password)) local autht = { nredirects = reqt.nredirects, method = reqt.method, @@ -429,8 +429,8 @@ local function authorize(reqt, parsed, respt) body_cb = reqt.body_cb, headers = reqt.headers, timeout = reqt.timeout, - host = reqt.host, - port = reqt.port + proxyhost = reqt.proxyhost, + proxyport = reqt.proxyport } return request_cb(autht, respt) end @@ -471,8 +471,8 @@ local function redirect(reqt, respt) body_cb = reqt.body_cb, headers = reqt.headers, timeout = reqt.timeout, - host = reqt.host, - port = reqt.port + proxyhost = reqt.proxyhost, + proxyport = reqt.proxyport } respt = request_cb(redirt, respt) -- we pass the location header as a clue we tried to redirect @@ -482,8 +482,8 @@ end ----------------------------------------------------------------------------- -- Computes the request URI from the parsed request URL --- If host and port are given in the request table, we use he --- absoluteURI format. Otherwise, we use the abs_path format. +-- If we are using a proxy, we use the absoluteURI format. +-- Otherwise, we use the abs_path format. -- Input -- parsed: parsed URL -- Returns @@ -491,7 +491,7 @@ end ----------------------------------------------------------------------------- local function request_uri(reqt, parsed) local url - if not reqt.host and not reqt.port then + if not reqt.proxyhost and not reqt.proxyport then url = { path = parsed.path, params = parsed.params, @@ -543,6 +543,7 @@ end -- error: error message, or nil if successfull ----------------------------------------------------------------------------- function request_cb(reqt, respt) + local sock, ret local parsed = socket.url.parse(reqt.url, { host = "", port = PORT, @@ -561,14 +562,14 @@ function request_cb(reqt, respt) -- fill default headers reqt.headers = fill_headers(reqt.headers, parsed) -- try to connect to server - local sock sock, respt.error = socket.tcp() if not sock then return respt end -- set connection timeout so that we do not hang forever sock:settimeout(reqt.timeout or TIMEOUT) - local ret - ret, respt.error = sock:connect(reqt.host or parsed.host, - reqt.port or parsed.port) + ret, respt.error = sock:connect( + reqt.proxyhost or PROXYHOST or parsed.host, + reqt.proxyport or PROXYPORT or parsed.port + ) if not ret then sock:close() return respt diff --git a/src/inet.c b/src/inet.c index 282d616..6aea596 100644 --- a/src/inet.c +++ b/src/inet.c @@ -234,7 +234,7 @@ const char *inet_trybind(p_sock ps, const char *address, unsigned short port, return sock_bindstrerror(); } else { sock_setnonblocking(ps); - if (backlog > 0) sock_listen(ps, backlog); + if (backlog >= 0) sock_listen(ps, backlog); return NULL; } } diff --git a/src/luasocket.c b/src/luasocket.c index 8b30f4d..578d65c 100644 --- a/src/luasocket.c +++ b/src/luasocket.c @@ -33,7 +33,7 @@ #include "tcp.h" #include "udp.h" #include "select.h" -#include "code.h" +#include "mime.h" /*=========================================================================*\ * Exported functions @@ -52,13 +52,13 @@ LUASOCKET_API int luaopen_socket(lua_State *L) tcp_open(L); udp_open(L); select_open(L); - code_open(L); + mime_open(L); #ifdef LUASOCKET_COMPILED #include "auxiliar.lch" #include "concat.lch" #include "url.lch" #include "callback.lch" -#include "code.lch" +#include "mime.lch" #include "smtp.lch" #include "ftp.lch" #include "http.lch" @@ -67,7 +67,7 @@ LUASOCKET_API int luaopen_socket(lua_State *L) lua_dofile(L, "concat.lua"); lua_dofile(L, "url.lua"); lua_dofile(L, "callback.lua"); - lua_dofile(L, "code.lua"); + lua_dofile(L, "mime.lua"); lua_dofile(L, "smtp.lua"); lua_dofile(L, "ftp.lua"); lua_dofile(L, "http.lua"); diff --git a/src/mime.c b/src/mime.c new file mode 100644 index 0000000..6807af5 --- /dev/null +++ b/src/mime.c @@ -0,0 +1,614 @@ +/*=========================================================================*\ +* Encoding support functions +* LuaSocket toolkit +* +* RCS ID: $Id$ +\*=========================================================================*/ +#include + +#include +#include + +#include "luasocket.h" +#include "mime.h" + +/*=========================================================================*\ +* Don't want to trust escape character constants +\*=========================================================================*/ +#define CR 0x0D +#define LF 0x0A +#define HT 0x09 +#define SP 0x20 + +typedef unsigned char UC; +static const UC CRLF[2] = {CR, LF}; +static const UC EQCRLF[3] = {'=', CR, LF}; + +/*=========================================================================*\ +* Internal function prototypes. +\*=========================================================================*/ +static int mime_global_fmt(lua_State *L); +static int mime_global_b64(lua_State *L); +static int mime_global_unb64(lua_State *L); +static int mime_global_qp(lua_State *L); +static int mime_global_unqp(lua_State *L); +static int mime_global_qpfmt(lua_State *L); +static int mime_global_eol(lua_State *L); + +static void b64fill(UC *b64unbase); +static size_t b64encode(UC c, UC *input, size_t size, luaL_Buffer *buffer); +static size_t b64pad(const UC *input, size_t size, luaL_Buffer *buffer); +static size_t b64decode(UC c, UC *input, size_t size, luaL_Buffer *buffer); + +static void qpfill(UC *qpclass, UC *qpunbase); +static void qpquote(UC c, luaL_Buffer *buffer); +static size_t qpdecode(UC c, UC *input, size_t size, luaL_Buffer *buffer); +static size_t qpencode(UC c, UC *input, size_t size, + const UC *marker, luaL_Buffer *buffer); + +/* code support functions */ +static luaL_reg func[] = { + { "eol", mime_global_eol }, + { "qp", mime_global_qp }, + { "unqp", mime_global_unqp }, + { "qpfmt", mime_global_qpfmt }, + { "b64", mime_global_b64 }, + { "unb64", mime_global_unb64 }, + { "fmt", mime_global_fmt }, + { NULL, NULL } +}; + +/*-------------------------------------------------------------------------*\ +* Quoted-printable globals +\*-------------------------------------------------------------------------*/ +static UC qpclass[256]; +static UC qpbase[] = "0123456789ABCDEF"; +static UC qpunbase[256]; +enum {QP_PLAIN, QP_QUOTED, QP_CR, QP_IF_LAST}; + +/*-------------------------------------------------------------------------*\ +* Base64 globals +\*-------------------------------------------------------------------------*/ +static const UC b64base[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; +static UC b64unbase[256]; + +/*=========================================================================*\ +* Exported functions +\*=========================================================================*/ +/*-------------------------------------------------------------------------*\ +* Initializes module +\*-------------------------------------------------------------------------*/ +void mime_open(lua_State *L) +{ + lua_pushstring(L, LUASOCKET_LIBNAME); + lua_gettable(L, LUA_GLOBALSINDEX); + if (lua_isnil(L, -1)) { + lua_pop(L, 1); + lua_newtable(L); + lua_pushstring(L, LUASOCKET_LIBNAME); + lua_pushvalue(L, -2); + lua_settable(L, LUA_GLOBALSINDEX); + } + lua_pushstring(L, "mime"); + lua_newtable(L); + luaL_openlib(L, NULL, func, 0); + lua_settable(L, -3); + lua_pop(L, 1); + /* initialize lookup tables */ + qpfill(qpclass, qpunbase); + b64fill(b64unbase); +} + +/*=========================================================================*\ +* Global Lua functions +\*=========================================================================*/ +/*-------------------------------------------------------------------------*\ +* Incrementaly breaks a string into lines +* A, n = fmt(B, length, left) +* A is a copy of B, broken into lines of at most 'length' bytes. +* Left is how many bytes are left in the first line of B. 'n' is the number +* of bytes left in the last line of A. +\*-------------------------------------------------------------------------*/ +static int mime_global_fmt(lua_State *L) +{ + size_t size = 0; + const UC *input = lua_isnil(L, 1)? NULL: luaL_checklstring(L, 1, &size); + const UC *last = input + size; + int length = (int) luaL_checknumber(L, 2); + int left = (int) luaL_optnumber(L, 3, length); + const UC *marker = luaL_optstring(L, 4, CRLF); + luaL_Buffer buffer; + luaL_buffinit(L, &buffer); + while (input < last) { + luaL_putchar(&buffer, *input++); + if (--left <= 0) { + luaL_addstring(&buffer, marker); + left = length; + } + } + if (!input && left < length) { + luaL_addstring(&buffer, marker); + left = length; + } + luaL_pushresult(&buffer); + lua_pushnumber(L, left); + return 2; +} + +/*-------------------------------------------------------------------------*\ +* Fill base64 decode map. +\*-------------------------------------------------------------------------*/ +static void b64fill(UC *b64unbase) +{ + int i; + for (i = 0; i < 255; i++) b64unbase[i] = 255; + for (i = 0; i < 64; i++) b64unbase[b64base[i]] = i; + b64unbase['='] = 0; +} + +/*-------------------------------------------------------------------------*\ +* Acumulates bytes in input buffer until 3 bytes are available. +* Translate the 3 bytes into Base64 form and append to buffer. +* Returns new number of bytes in buffer. +\*-------------------------------------------------------------------------*/ +static size_t b64encode(UC c, UC *input, size_t size, + luaL_Buffer *buffer) +{ + input[size++] = c; + if (size == 3) { + UC code[4]; + unsigned long value = 0; + value += input[0]; value <<= 8; + value += input[1]; value <<= 8; + value += input[2]; + code[3] = b64base[value & 0x3f]; value >>= 6; + code[2] = b64base[value & 0x3f]; value >>= 6; + code[1] = b64base[value & 0x3f]; value >>= 6; + code[0] = b64base[value]; + luaL_addlstring(buffer, code, 4); + size = 0; + } + return size; +} + +/*-------------------------------------------------------------------------*\ +* Encodes the Base64 last 1 or 2 bytes and adds padding '=' +* Result, if any, is appended to buffer. +* Returns 0. +\*-------------------------------------------------------------------------*/ +static size_t b64pad(const UC *input, size_t size, + luaL_Buffer *buffer) +{ + unsigned long value = 0; + UC code[4] = "===="; + switch (size) { + case 1: + value = input[0] << 4; + code[1] = b64base[value & 0x3f]; value >>= 6; + code[0] = b64base[value]; + luaL_addlstring(buffer, code, 4); + break; + case 2: + value = input[0]; value <<= 8; + value |= input[1]; value <<= 2; + code[2] = b64base[value & 0x3f]; value >>= 6; + code[1] = b64base[value & 0x3f]; value >>= 6; + code[0] = b64base[value]; + luaL_addlstring(buffer, code, 4); + break; + case 0: /* fall through */ + default: + break; + } + return 0; +} + +/*-------------------------------------------------------------------------*\ +* Acumulates bytes in input buffer until 4 bytes are available. +* Translate the 4 bytes from Base64 form and append to buffer. +* Returns new number of bytes in buffer. +\*-------------------------------------------------------------------------*/ +static size_t b64decode(UC c, UC *input, size_t size, + luaL_Buffer *buffer) +{ + + /* ignore invalid characters */ + if (b64unbase[c] > 64) return size; + input[size++] = c; + /* decode atom */ + if (size == 4) { + UC decoded[3]; + int valid, value = 0; + value = b64unbase[input[0]]; value <<= 6; + value |= b64unbase[input[1]]; value <<= 6; + value |= b64unbase[input[2]]; value <<= 6; + value |= b64unbase[input[3]]; + decoded[2] = (UC) (value & 0xff); value >>= 8; + decoded[1] = (UC) (value & 0xff); value >>= 8; + decoded[0] = (UC) value; + /* take care of paddding */ + valid = (input[2] == '=') ? 1 : (input[3] == '=') ? 2 : 3; + luaL_addlstring(buffer, decoded, valid); + return 0; + /* need more data */ + } else return size; +} + +/*-------------------------------------------------------------------------*\ +* Incrementally applies the Base64 transfer content encoding to a string +* A, B = b64(C, D) +* A is the encoded version of the largest prefix of C .. D that is +* divisible by 3. B has the remaining bytes of C .. D, *without* encoding. +* The easiest thing would be to concatenate the two strings and +* encode the result, but we can't afford that or Lua would dupplicate +* every chunk we received. +\*-------------------------------------------------------------------------*/ +static int mime_global_b64(lua_State *L) +{ + UC atom[3]; + size_t isize = 0, asize = 0; + const UC *input = luaL_checklstring(L, 1, &isize); + const UC *last = input + isize; + luaL_Buffer buffer; + luaL_buffinit(L, &buffer); + while (input < last) + asize = b64encode(*input++, atom, asize, &buffer); + input = luaL_optlstring(L, 2, NULL, &isize); + if (input) { + last = input + isize; + while (input < last) + asize = b64encode(*input++, atom, asize, &buffer); + } else + asize = b64pad(atom, asize, &buffer); + luaL_pushresult(&buffer); + lua_pushlstring(L, atom, asize); + return 2; +} + +/*-------------------------------------------------------------------------*\ +* Incrementally removes the Base64 transfer content encoding from a string +* A, B = b64(C, D) +* A is the encoded version of the largest prefix of C .. D that is +* divisible by 4. B has the remaining bytes of C .. D, *without* encoding. +\*-------------------------------------------------------------------------*/ +static int mime_global_unb64(lua_State *L) +{ + UC atom[4]; + size_t isize = 0, asize = 0; + const UC *input = luaL_checklstring(L, 1, &isize); + const UC *last = input + isize; + luaL_Buffer buffer; + luaL_buffinit(L, &buffer); + while (input < last) + asize = b64decode(*input++, atom, asize, &buffer); + input = luaL_optlstring(L, 2, NULL, &isize); + if (input) { + last = input + isize; + while (input < last) + asize = b64decode(*input++, atom, asize, &buffer); + } + luaL_pushresult(&buffer); + lua_pushlstring(L, atom, asize); + return 2; +} + +/*-------------------------------------------------------------------------*\ +* Quoted-printable encoding scheme +* all (except CRLF in text) can be =XX +* CLRL in not text must be =XX=XX +* 33 through 60 inclusive can be plain +* 62 through 120 inclusive can be plain +* 9 and 32 can be plain, unless in the end of a line, where must be =XX +* encoded lines must be no longer than 76 not counting CRLF +* soft line-break are =CRLF +* !"#$@[\]^`{|}~ should be =XX for EBCDIC compatibility +* To encode one byte, we need to see the next two. +* Worst case is when we see a space, and wonder if a CRLF is comming +\*-------------------------------------------------------------------------*/ +/*-------------------------------------------------------------------------*\ +* Split quoted-printable characters into classes +* Precompute reverse map for encoding +\*-------------------------------------------------------------------------*/ +static void qpfill(UC *qpclass, UC *qpunbase) +{ + int i; + for (i = 0; i < 256; i++) qpclass[i] = QP_QUOTED; + for (i = 33; i <= 60; i++) qpclass[i] = QP_PLAIN; + for (i = 62; i <= 120; i++) qpclass[i] = QP_PLAIN; + qpclass[HT] = QP_IF_LAST; qpclass[SP] = QP_IF_LAST; + qpclass['!'] = QP_QUOTED; qpclass['"'] = QP_QUOTED; + qpclass['#'] = QP_QUOTED; qpclass['$'] = QP_QUOTED; + qpclass['@'] = QP_QUOTED; qpclass['['] = QP_QUOTED; + qpclass['\\'] = QP_QUOTED; qpclass[']'] = QP_QUOTED; + qpclass['^'] = QP_QUOTED; qpclass['`'] = QP_QUOTED; + qpclass['{'] = QP_QUOTED; qpclass['|'] = QP_QUOTED; + qpclass['}'] = QP_QUOTED; qpclass['~'] = QP_QUOTED; + qpclass['}'] = QP_QUOTED; qpclass[CR] = QP_CR; + for (i = 0; i < 256; i++) qpunbase[i] = 255; + qpunbase['0'] = 0; qpunbase['1'] = 1; qpunbase['2'] = 2; + qpunbase['3'] = 3; qpunbase['4'] = 4; qpunbase['5'] = 5; + qpunbase['6'] = 6; qpunbase['7'] = 7; qpunbase['8'] = 8; + qpunbase['9'] = 9; qpunbase['A'] = 10; qpunbase['a'] = 10; + qpunbase['B'] = 11; qpunbase['b'] = 11; qpunbase['C'] = 12; + qpunbase['c'] = 12; qpunbase['D'] = 13; qpunbase['d'] = 13; + qpunbase['E'] = 14; qpunbase['e'] = 14; qpunbase['F'] = 15; + qpunbase['f'] = 15; +} + +/*-------------------------------------------------------------------------*\ +* Output one character in form =XX +\*-------------------------------------------------------------------------*/ +static void qpquote(UC c, luaL_Buffer *buffer) +{ + luaL_putchar(buffer, '='); + luaL_putchar(buffer, qpbase[c >> 4]); + luaL_putchar(buffer, qpbase[c & 0x0F]); +} + +/*-------------------------------------------------------------------------*\ +* Accumulate characters until we are sure about how to deal with them. +* Once we are sure, output the to the buffer, in the correct form. +\*-------------------------------------------------------------------------*/ +static size_t qpencode(UC c, UC *input, size_t size, + const UC *marker, luaL_Buffer *buffer) +{ + input[size++] = c; + /* deal with all characters we can have */ + while (size > 0) { + switch (qpclass[input[0]]) { + /* might be the CR of a CRLF sequence */ + case QP_CR: + if (size < 2) return size; + if (input[1] == LF) { + luaL_addstring(buffer, marker); + return 0; + } else qpquote(input[0], buffer); + break; + /* might be a space and that has to be quoted if last in line */ + case QP_IF_LAST: + if (size < 3) return size; + /* if it is the last, quote it and we are done */ + if (input[1] == CR && input[2] == LF) { + qpquote(input[0], buffer); + luaL_addstring(buffer, marker); + return 0; + } else luaL_putchar(buffer, input[0]); + break; + /* might have to be quoted always */ + case QP_QUOTED: + qpquote(input[0], buffer); + break; + /* might never have to be quoted */ + default: + luaL_putchar(buffer, input[0]); + break; + } + input[0] = input[1]; input[1] = input[2]; + size--; + } + return 0; +} + +/*-------------------------------------------------------------------------*\ +* Deal with the final characters +\*-------------------------------------------------------------------------*/ +static void qppad(UC *input, size_t size, luaL_Buffer *buffer) +{ + size_t i; + for (i = 0; i < size; i++) { + if (qpclass[input[i]] == QP_PLAIN) luaL_putchar(buffer, input[i]); + else qpquote(input[i], buffer); + } + luaL_addstring(buffer, EQCRLF); +} + +/*-------------------------------------------------------------------------*\ +* Incrementally converts a string to quoted-printable +* A, B = qp(C, D, marker) +* Crlf is the text to be used to replace CRLF sequences found in A. +* A is the encoded version of the largest prefix of C .. D that +* can be encoded without doubts. +* B has the remaining bytes of C .. D, *without* encoding. +\*-------------------------------------------------------------------------*/ +static int mime_global_qp(lua_State *L) +{ + + size_t asize = 0, isize = 0; + UC atom[3]; + const UC *input = lua_isnil(L, 1) ? NULL: luaL_checklstring(L, 1, &isize); + const UC *last = input + isize; + const UC *marker = luaL_optstring(L, 3, CRLF); + luaL_Buffer buffer; + luaL_buffinit(L, &buffer); + while (input < last) + asize = qpencode(*input++, atom, asize, marker, &buffer); + input = luaL_optlstring(L, 2, NULL, &isize); + if (input) { + last = input + isize; + while (input < last) + asize = qpencode(*input++, atom, asize, marker, &buffer); + } else qppad(atom, asize, &buffer); + luaL_pushresult(&buffer); + lua_pushlstring(L, atom, asize); + return 2; +} + +/*-------------------------------------------------------------------------*\ +* Accumulate characters until we are sure about how to deal with them. +* Once we are sure, output the to the buffer, in the correct form. +\*-------------------------------------------------------------------------*/ +static size_t qpdecode(UC c, UC *input, size_t size, + luaL_Buffer *buffer) +{ + input[size++] = c; + /* deal with all characters we can deal */ + while (size > 0) { + int c, d; + switch (input[0]) { + /* if we have an escape character */ + case '=': + if (size < 3) return size; + /* eliminate soft line break */ + if (input[1] == CR && input[2] == LF) return 0; + /* decode quoted representation */ + c = qpunbase[input[1]]; d = qpunbase[input[2]]; + /* if it is an invalid, do not decode */ + if (c > 15 || d > 15) luaL_addlstring(buffer, input, 3); + else luaL_putchar(buffer, (c << 4) + d); + return 0; + case CR: + if (size < 2) return size; + if (input[1] == LF) luaL_addlstring(buffer, input, 2); + return 0; + default: + if (input[0] == HT || (input[0] > 31 && input[0] < 127)) + luaL_putchar(buffer, input[0]); + return 0; + } + input[0] = input[1]; input[1] = input[2]; + size--; + } + return 0; +} + +/*-------------------------------------------------------------------------*\ +* Incrementally decodes a string in quoted-printable +* A, B = qp(C, D) +* A is the decoded version of the largest prefix of C .. D that +* can be decoded without doubts. +* B has the remaining bytes of C .. D, *without* decoding. +\*-------------------------------------------------------------------------*/ +static int mime_global_unqp(lua_State *L) +{ + + size_t asize = 0, isize = 0; + UC atom[3]; + const UC *input = lua_isnil(L, 1) ? NULL: luaL_checklstring(L, 1, &isize); + const UC *last = input + isize; + luaL_Buffer buffer; + luaL_buffinit(L, &buffer); + while (input < last) + asize = qpdecode(*input++, atom, asize, &buffer); + input = luaL_optlstring(L, 2, NULL, &isize); + if (input) { + last = input + isize; + while (input < last) + asize = qpdecode(*input++, atom, asize, &buffer); + } + luaL_pushresult(&buffer); + lua_pushlstring(L, atom, asize); + return 2; +} + +/*-------------------------------------------------------------------------*\ +* Incrementally breaks a quoted-printed string into lines +* A, n = qpfmt(B, length, left) +* A is a copy of B, broken into lines of at most 'length' bytes. +* Left is how many bytes are left in the first line of B. 'n' is the number +* of bytes left in the last line of A. +* There are two complications: lines can't be broken in the middle +* of an encoded =XX, and there might be line breaks already +\*-------------------------------------------------------------------------*/ +static int mime_global_qpfmt(lua_State *L) +{ + size_t size = 0; + const UC *input = lua_isnil(L, 1)? NULL: luaL_checklstring(L, 1, &size); + const UC *last = input + size; + int length = (int) luaL_checknumber(L, 2); + int left = (int) luaL_optnumber(L, 3, length); + luaL_Buffer buffer; + luaL_buffinit(L, &buffer); + while (input < last) { + left--; + switch (*input) { + case '=': + /* if there's no room in this line for the quoted char, + * output a soft line break now */ + if (left <= 3) { + luaL_addstring(&buffer, EQCRLF); + left = length; + } + break; + /* \r\n starts a new line */ + case CR: + break; + case LF: + left = length; + break; + default: + /* if in last column, output a soft line break */ + if (left <= 1) { + luaL_addstring(&buffer, EQCRLF); + left = length; + } + } + luaL_putchar(&buffer, *input); + input++; + } + if (!input && left < length) { + luaL_addstring(&buffer, EQCRLF); + left = length; + } + luaL_pushresult(&buffer); + lua_pushnumber(L, left); + return 2; +} + +/*-------------------------------------------------------------------------*\ +* Here is what we do: \n, \r and \f are considered candidates for line +* break. We issue *one* new line marker if any of them is seen alone, or +* followed by a different one. That is, \n\n, \r\r and \f\f will issue two +* end of line markers each, but \r\n, \n\r, \r\f etc will only issue *one* +* marker. This covers Mac OS, Mac OS X, VMS, Unix and DOS, as well as +* probably other more obscure conventions. +\*-------------------------------------------------------------------------*/ +#define eolcandidate(c) (c == CR || c == LF) +static size_t eolconvert(UC c, UC *input, size_t size, + const UC *marker, luaL_Buffer *buffer) +{ + input[size++] = c; + /* deal with all characters we can deal */ + if (eolcandidate(input[0])) { + if (size < 2) return size; + luaL_addstring(buffer, marker); + if (eolcandidate(input[1])) { + if (input[0] == input[1]) luaL_addstring(buffer, marker); + } else luaL_putchar(buffer, input[1]); + return 0; + } else { + luaL_putchar(buffer, input[0]); + return 0; + } +} + +/*-------------------------------------------------------------------------*\ +* Converts a string to uniform EOL convention. +* A, B = eol(C, D, marker) +* A is the converted version of the largest prefix of C .. D that +* can be converted without doubts. +* B has the remaining bytes of C .. D, *without* convertion. +\*-------------------------------------------------------------------------*/ +static int mime_global_eol(lua_State *L) +{ + size_t asize = 0, isize = 0; + UC atom[2]; + const UC *input = lua_isnil(L, 1)? NULL: luaL_checklstring(L, 1, &isize); + const UC *last = input + isize; + const UC *marker = luaL_optstring(L, 3, CRLF); + luaL_Buffer buffer; + luaL_buffinit(L, &buffer); + while (input < last) + asize = eolconvert(*input++, atom, asize, marker, &buffer); + input = luaL_optlstring(L, 2, NULL, &isize); + if (input) { + last = input + isize; + while (input < last) + asize = eolconvert(*input++, atom, asize, marker, &buffer); + /* if there is something in atom, it's one character, and it + * is a candidate. so we output a new line */ + } else if (asize > 0) luaL_addstring(&buffer, marker); + luaL_pushresult(&buffer); + lua_pushlstring(L, atom, asize); + return 2; +} diff --git a/src/mime.h b/src/mime.h new file mode 100644 index 0000000..8323783 --- /dev/null +++ b/src/mime.h @@ -0,0 +1,17 @@ +#ifndef MIME_H +#define MIME_H +/*=========================================================================*\ +* Mime support functions +* LuaSocket toolkit +* +* This module provides functions to implement transfer content encodings +* and formatting conforming to RFC 2045. It is used by mime.lua, which +* provide a higher level interface to this functionality. +* +* RCS ID: $Id$ +\*=========================================================================*/ +#include + +void mime_open(lua_State *L); + +#endif /* MIME_H */ diff --git a/src/mime.lua b/src/mime.lua new file mode 100644 index 0000000..86b3af2 --- /dev/null +++ b/src/mime.lua @@ -0,0 +1,104 @@ +-- make sure LuaSocket is loaded +if not LUASOCKET_LIBNAME then error('module requires LuaSocket') end +-- get LuaSocket namespace +local socket = _G[LUASOCKET_LIBNAME] +if not socket then error('module requires LuaSocket') end +-- create code namespace inside LuaSocket namespace +local mime = socket.mime or {} +socket.mime = mime +-- make all module globals fall into mime namespace +setmetatable(mime, { __index = _G }) +setfenv(1, mime) + +base64 = {} +qprint = {} + +function base64.encode() + local unfinished = "" + return function(chunk) + local done + done, unfinished = b64(unfinished, chunk) + return done + end +end + +function base64.decode() + local unfinished = "" + return function(chunk) + local done + done, unfinished = unb64(unfinished, chunk) + return done + end +end + +function qprint.encode(mode) + mode = (mode == "binary") and "=0D=0A" or "\13\10" + local unfinished = "" + return function(chunk) + local done + done, unfinished = qp(unfinished, chunk, mode) + return done + end +end + +function qprint.decode() + local unfinished = "" + return function(chunk) + local done + done, unfinished = unqp(unfinished, chunk) + return done + end +end + +function split(length, marker) + length = length or 76 + local left = length + return function(chunk) + local done + done, left = fmt(chunk, length, left, marker) + return done + end +end + +function qprint.split(length) + length = length or 76 + local left = length + return function(chunk) + local done + done, left = qpfmt(chunk, length, left) + return done + end +end + +function canonic(marker) + local unfinished = "" + return function(chunk) + local done + done, unfinished = eol(unfinished, chunk, marker) + return done + end +end + +function chain(...) + local layers = table.getn(arg) + return function (chunk) + if not chunk then + local parts = {} + for i = 1, layers do + for j = i, layers do + chunk = arg[j](chunk) + end + table.insert(parts, chunk) + chunk = nil + end + return table.concat(parts) + else + for j = 1, layers do + chunk = arg[j](chunk) + end + return chunk + end + end +end + +return code diff --git a/src/tcp.c b/src/tcp.c index ce2ae17..b4b9fd9 100644 --- a/src/tcp.c +++ b/src/tcp.c @@ -238,7 +238,7 @@ static int meth_bind(lua_State *L) return 2; } /* turn master object into a server object if there was a listen */ - if (backlog > 0) aux_setclass(L, "tcp{server}", 1); + if (backlog >= 0) aux_setclass(L, "tcp{server}", 1); lua_pushnumber(L, 1); return 1; } diff --git a/src/url.lua b/src/url.lua index 27e7928..ab3a922 100644 --- a/src/url.lua +++ b/src/url.lua @@ -6,9 +6,102 @@ -- RCS ID: $Id$ ---------------------------------------------------------------------------- -local Public, Private = {}, {} -local socket = _G[LUASOCKET_LIBNAME] -- get LuaSocket namespace -socket.url = Public +-- make sure LuaSocket is loaded +if not LUASOCKET_LIBNAME then error('module requires LuaSocket') end +-- get LuaSocket namespace +local socket = _G[LUASOCKET_LIBNAME] +if not socket then error('module requires LuaSocket') end +-- create smtp namespace inside LuaSocket namespace +local url = {} +socket.url = url +-- make all module globals fall into smtp namespace +setmetatable(url, { __index = _G }) +setfenv(1, url) + +----------------------------------------------------------------------------- +-- Encodes a string into its escaped hexadecimal representation +-- Input +-- s: binary string to be encoded +-- Returns +-- escaped representation of string binary +----------------------------------------------------------------------------- +function escape(s) + return string.gsub(s, "(.)", function(c) + return string.format("%%%02x", string.byte(c)) + end) +end + +----------------------------------------------------------------------------- +-- Protects a path segment, to prevent it from interfering with the +-- url parsing. +-- Input +-- s: binary string to be encoded +-- Returns +-- escaped representation of string binary +----------------------------------------------------------------------------- +local function make_set(t) + local s = {} + for i = 1, table.getn(t) do + s[t[i]] = 1 + end + return s +end + +-- these are allowed withing a path segment, along with alphanum +-- other characters must be escaped +local segment_set = make_set { + "-", "_", ".", "!", "~", "*", "'", "(", + ")", ":", "@", "&", "=", "+", "$", ",", +} + +local function protect_segment(s) + return string.gsub(s, "(%W)", function (c) + if segment_set[c] then return c + else return escape(c) end + end) +end + +----------------------------------------------------------------------------- +-- Encodes a string into its escaped hexadecimal representation +-- Input +-- s: binary string to be encoded +-- Returns +-- escaped representation of string binary +----------------------------------------------------------------------------- +function unescape(s) + return string.gsub(s, "%%(%x%x)", function(hex) + return string.char(tonumber(hex, 16)) + end) +end + +----------------------------------------------------------------------------- +-- Builds a path from a base path and a relative path +-- Input +-- base_path +-- relative_path +-- Returns +-- corresponding absolute path +----------------------------------------------------------------------------- +local function absolute_path(base_path, relative_path) + if string.sub(relative_path, 1, 1) == "/" then return relative_path end + local path = string.gsub(base_path, "[^/]*$", "") + path = path .. relative_path + path = string.gsub(path, "([^/]*%./)", function (s) + if s ~= "./" then return s else return "" end + end) + path = string.gsub(path, "/%.$", "/") + local reduced + while reduced ~= path do + reduced = path + path = string.gsub(reduced, "([^/]*/%.%./)", function (s) + if s ~= "../../" then return "" else return s end + end) + end + path = string.gsub(reduced, "([^/]*/%.%.)$", function (s) + if s ~= "../.." then return "" else return s end + end) + return path +end ----------------------------------------------------------------------------- -- Parses a url and returns a table with all its parts according to RFC 2396 @@ -28,7 +121,7 @@ socket.url = Public -- Obs: -- the leading '/' in {/} is considered part of ----------------------------------------------------------------------------- -function Public.parse(url, default) +function parse(url, default) -- initialize default parameters local parsed = default or {} -- empty url is parsed to nil @@ -66,11 +159,11 @@ end -- Rebuilds a parsed URL from its components. -- Components are protected if any reserved or unallowed characters are found -- Input --- parsed: parsed URL, as returned by Public.parse +-- parsed: parsed URL, as returned by parse -- Returns -- a stringing with the corresponding URL ----------------------------------------------------------------------------- -function Public.build(parsed) +function build(parsed) local url = parsed.path or "" if parsed.params then url = url .. ";" .. parsed.params end if parsed.query then url = url .. "?" .. parsed.query end @@ -102,9 +195,9 @@ end -- Returns -- corresponding absolute url ----------------------------------------------------------------------------- -function Public.absolute(base_url, relative_url) - local base = Public.parse(base_url) - local relative = Public.parse(relative_url) +function absolute(base_url, relative_url) + local base = parse(base_url) + local relative = parse(relative_url) if not base then return relative_url elseif not relative then return base_url elseif relative.scheme then return relative_url @@ -121,10 +214,10 @@ function Public.absolute(base_url, relative_url) end end else - relative.path = Private.absolute_path(base.path,relative.path) + relative.path = absolute_path(base.path,relative.path) end end - return Public.build(relative) + return build(relative) end end @@ -135,13 +228,13 @@ end -- Returns -- segment: a table with one entry per segment ----------------------------------------------------------------------------- -function Public.parse_path(path) +function parse_path(path) local parsed = {} path = path or "" path = string.gsub(path, "%s", "") string.gsub(path, "([^/]+)", function (s) table.insert(parsed, s) end) for i = 1, table.getn(parsed) do - parsed[i] = socket.code.unescape(parsed[i]) + parsed[i] = unescape(parsed[i]) end if string.sub(path, 1, 1) == "/" then parsed.is_absolute = 1 end if string.sub(path, -1, -1) == "/" then parsed.is_directory = 1 end @@ -154,9 +247,9 @@ end -- parsed: path segments -- unsafe: if true, segments are not protected before path is built -- Returns --- path: correspondin path stringing +-- path: corresponding path stringing ----------------------------------------------------------------------------- -function Public.build_path(parsed, unsafe) +function build_path(parsed, unsafe) local path = "" local n = table.getn(parsed) if unsafe then @@ -170,66 +263,14 @@ function Public.build_path(parsed, unsafe) end else for i = 1, n-1 do - path = path .. Private.protect_segment(parsed[i]) + path = path .. protect_segment(parsed[i]) path = path .. "/" end if n > 0 then - path = path .. Private.protect_segment(parsed[n]) + path = path .. protect_segment(parsed[n]) if parsed.is_directory then path = path .. "/" end end end if parsed.is_absolute then path = "/" .. path end return path end - -function Private.make_set(t) - local s = {} - for i = 1, table.getn(t) do - s[t[i]] = 1 - end - return s -end - --- these are allowed withing a path segment, along with alphanum --- other characters must be escaped -Private.segment_set = Private.make_set { - "-", "_", ".", "!", "~", "*", "'", "(", - ")", ":", "@", "&", "=", "+", "$", ",", -} - -function Private.protect_segment(s) - local segment_set = Private.segment_set - return string.gsub(s, "(%W)", function (c) - if segment_set[c] then return c - else return socket.code.escape(c) end - end) -end - ------------------------------------------------------------------------------ --- Builds a path from a base path and a relative path --- Input --- base_path --- relative_path --- Returns --- corresponding absolute path ------------------------------------------------------------------------------ -function Private.absolute_path(base_path, relative_path) - if string.sub(relative_path, 1, 1) == "/" then return relative_path end - local path = string.gsub(base_path, "[^/]*$", "") - path = path .. relative_path - path = string.gsub(path, "([^/]*%./)", function (s) - if s ~= "./" then return s else return "" end - end) - path = string.gsub(path, "/%.$", "/") - local reduced - while reduced ~= path do - reduced = path - path = string.gsub(reduced, "([^/]*/%.%./)", function (s) - if s ~= "../../" then return "" else return s end - end) - end - path = string.gsub(reduced, "([^/]*/%.%.)$", function (s) - if s ~= "../.." then return "" else return s end - end) - return path -end diff --git a/test/httptest.lua b/test/httptest.lua index 3d0db87..dc90741 100644 --- a/test/httptest.lua +++ b/test/httptest.lua @@ -5,14 +5,14 @@ -- needs "AllowOverride AuthConfig" on /home/c/diego/tec/luasocket/test/auth dofile("noglobals.lua") -local host, proxyh, proxyp, request, response +local host, proxyhost, proxyport, request, response local ignore, expect, index, prefix, cgiprefix local t = socket.time() host = host or "diego.princeton.edu" -proxyh = proxyh or "localhost" -proxyp = proxyp or 3128 +proxyhost = proxyhost or "localhost" +proxyport = proxyport or 3128 prefix = prefix or "/luasocket-test" cgiprefix = cgiprefix or "/luasocket-test-cgi" @@ -129,8 +129,8 @@ request = { method = "POST", body = index, headers = { ["content-length"] = string.len(index) }, - port = proxyp, - host = proxyh + proxyport = proxyport, + proxyhost = proxyhost } expect = { body = index, @@ -170,8 +170,8 @@ check_request(request, expect, ignore) io.write("testing proxy with redirection: ") request = { url = "http://" .. host .. prefix, - host = proxyh, - port = proxyp + proxyhost = proxyhost, + proxyport = proxyport } expect = { body = index, @@ -267,7 +267,7 @@ io.write("testing manual basic auth: ") request = { url = "http://" .. host .. prefix .. "/auth/index.html", headers = { - authorization = "Basic " .. (socket.code.b64("luasocket:password")) + authorization = "Basic " .. (socket.mime.b64("luasocket:password")) } } expect = { diff --git a/test/mimetest.lua b/test/mimetest.lua new file mode 100644 index 0000000..5485db1 --- /dev/null +++ b/test/mimetest.lua @@ -0,0 +1,236 @@ +dofile("noglobals.lua") + +local qptest = "qptest.bin" +local eqptest = "qptest.bin2" +local dqptest = "qptest.bin3" + +local b64test = "luasocket" +local eb64test = "b64test.bin" +local db64test = "b64test.bin2" + +-- from Machado de Assis, "A Mão e a Rosa" +local mao = [[ + Cursavam estes dois moços a academia de S. Paulo, estando + Luís Alves no quarto ano e Estêvão no terceiro. + Conheceram-se na academia, e ficaram amigos íntimos, tanto + quanto podiam sê-lo dois espíritos diferentes, ou talvez por + isso mesmo que o eram. Estêvão, dotado de extrema + sensibilidade, e não menor fraqueza de ânimo, afetuoso e + bom, não daquela bondade varonil, que é apanágio de uma alma + forte, mas dessa outra bondade mole e de cera, que vai à + mercê de todas as circunstâncias, tinha, além de tudo isso, + o infortúnio de trazer ainda sobre o nariz os óculos + cor-de-rosa de suas virginais ilusões. Luís Alves via bem + com os olhos da cara. Não era mau rapaz, mas tinha o seu + grão de egoísmo, e se não era incapaz de afeições, sabia + regê-las, moderá-las, e sobretudo guiá-las ao seu próprio + interesse. Entre estes dois homens travara-se amizade + íntima, nascida para um na simpatia, para outro no costume. + Eram eles os naturais confidentes um do outro, com a + diferença que Luís Alves dava menos do que recebia, e, ainda + assim, nem tudo o que dava exprimia grande confiança. +]] + +local fail = function(s) + s = s or "failed" + assert(nil, s) +end + +local readfile = function(name) + local f = io.open(name, "r") + if not f then return nil end + local s = f:read("*a") + f:close() + return s +end + +local function transform(input, output, filter) + local fi, err = io.open(input, "rb") + if not fi then fail(err) end + local fo, err = io.open(output, "wb") + if not fo then fail(err) end + while 1 do + local chunk = fi:read(math.random(0, 256)) + fo:write(filter(chunk)) + if not chunk then break end + end + fi:close() + fo:close() +end + +local function compare(input, output) + local original = readfile(input) + local recovered = readfile(output) + if original ~= recovered then fail("recovering failed") + else print("ok") end +end + +local function encode_qptest(mode) + local encode = socket.mime.qprint.encode(mode) + local split = socket.mime.qprint.split() + local chain = socket.mime.chain(encode, split) + transform(qptest, eqptest, chain) +end + +local function compare_qptest() + compare(qptest, dqptest) +end + +local function decode_qptest() + local decode = socket.mime.qprint.decode() + transform(eqptest, dqptest, decode) +end + +local function create_qptest() + local f, err = io.open(qptest, "wb") + if not f then fail(err) end + -- try all characters + for i = 0, 255 do + f:write(string.char(i)) + end + -- try all characters and different line sizes + for i = 0, 255 do + for j = 0, i do + f:write(string.char(i)) + end + f:write("\r\n") + end + -- test latin text + f:write(mao) + -- force soft line breaks and treatment of space/tab in end of line + local tab + f:write(string.gsub(mao, "(%s)", function(c) + if tab then + tab = nil + return "\t" + else + tab = 1 + return " " + end + end)) + -- test crazy end of line conventions + local eol = { "\r\n", "\r", "\n", "\n\r" } + local which = 0 + f:write(string.gsub(mao, "(\n)", function(c) + which = which + 1 + if which > 4 then which = 1 end + return eol[which] + end)) + for i = 1, 4 do + for j = 1, 4 do + f:write(eol[i]) + f:write(eol[j]) + end + end + -- try long spaced and tabbed lines + f:write("\r\n") + for i = 0, 255 do + f:write(string.char(9)) + end + f:write("\r\n") + for i = 0, 255 do + f:write(' ') + end + f:write("\r\n") + for i = 0, 255 do + f:write(string.char(9),' ') + end + f:write("\r\n") + for i = 0, 255 do + f:write(' ',string.char(32)) + end + f:write("\r\n") + + f:close() +end + +local function cleanup_qptest() + os.remove(qptest) + os.remove(eqptest) + os.remove(dqptest) +end + +local function encode_b64test() + local e1 = socket.mime.base64.encode() + local e2 = socket.mime.base64.encode() + local e3 = socket.mime.base64.encode() + local e4 = socket.mime.base64.encode() + local sp4 = socket.mime.split() + local sp3 = socket.mime.split(59) + local sp2 = socket.mime.split(30) + local sp1 = socket.mime.split(27) + local chain = socket.mime.chain(e1, sp1, e2, sp2, e3, sp3, e4, sp4) + transform(b64test, eb64test, chain) +end + +local function decode_b64test() + local d1 = socket.mime.base64.decode() + local d2 = socket.mime.base64.decode() + local d3 = socket.mime.base64.decode() + local d4 = socket.mime.base64.decode() + local chain = socket.mime.chain(d1, d2, d3, d4) + transform(eb64test, db64test, chain) +end + +local function cleanup_b64test() + os.remove(eb64test) + os.remove(db64test) +end + +local function compare_b64test() + compare(b64test, db64test) +end + +local function padcheck(original, encoded) + local e = (socket.mime.b64(original)) + local d = (socket.mime.unb64(encoded)) + if e ~= encoded then fail("encoding failed") end + if d ~= original then fail("decoding failed") end +end + +local function chunkcheck(original, encoded) + local len = string.len(original) + for i = 0, len do + local a = string.sub(original, 1, i) + local b = string.sub(original, i+1) + local e, r = socket.mime.b64(a, b) + local f = (socket.mime.b64(r)) + if (e .. f ~= encoded) then fail(e .. f) end + end +end + +local function padding_b64test() + padcheck("a", "YQ==") + padcheck("ab", "YWI=") + padcheck("abc", "YWJj") + padcheck("abcd", "YWJjZA==") + padcheck("abcde", "YWJjZGU=") + padcheck("abcdef", "YWJjZGVm") + padcheck("abcdefg", "YWJjZGVmZw==") + padcheck("abcdefgh", "YWJjZGVmZ2g=") + padcheck("abcdefghi", "YWJjZGVmZ2hp") + padcheck("abcdefghij", "YWJjZGVmZ2hpag==") + chunkcheck("abcdefgh", "YWJjZGVmZ2g=") + chunkcheck("abcdefghi", "YWJjZGVmZ2hp") + chunkcheck("abcdefghij", "YWJjZGVmZ2hpag==") + print("ok") +end + +local t = socket.time() + +create_qptest() +encode_qptest() +decode_qptest() +compare_qptest() +encode_qptest("binary") +decode_qptest() +compare_qptest() +cleanup_qptest() + +encode_b64test() +decode_b64test() +compare_b64test() +cleanup_b64test() +padding_b64test() + +print(string.format("done in %.2fs", socket.time() - t))