From bcc0c2a9f0be2ca796ef5206a78e283fe15e6186 Mon Sep 17 00:00:00 2001 From: Diego Nehab Date: Tue, 16 Mar 2004 06:42:53 +0000 Subject: [PATCH] New filter scheme. ltn12 and mime updated. smtp/ftp broken. --- TODO | 5 + etc/b64.lua | 17 +-- etc/get.lua | 47 +++--- src/auxiliar.c | 11 -- src/auxiliar.h | 1 - src/ftp.lua | 366 +++++++++++++-------------------------------- src/http.lua | 219 +++++++++++++-------------- src/ltn12.lua | 75 +++++++--- src/luasocket.c | 6 - src/mime.c | 137 +++++++++++------ src/mime.lua | 2 +- src/smtp.lua | 2 +- src/tp.lua | 111 ++++++++++++++ src/wsocket.c | 4 +- test/httptest.lua | 45 +++--- test/mimetest.lua | 31 ++-- test/stufftest.lua | 19 +++ 17 files changed, 568 insertions(+), 530 deletions(-) create mode 100644 src/tp.lua create mode 100644 test/stufftest.lua diff --git a/TODO b/TODO index 7479cfc..3f1c71b 100644 --- a/TODO +++ b/TODO @@ -19,6 +19,11 @@ * Separar as classes em arquivos * Retorno de sendto em datagram sockets pode ser refused +colocar um userdata com gc metamethod pra chamar sock_close (WSAClose); +sources ans sinks are always simple in http and ftp and smtp +unify backbone of smtp and ftp +expose encode/decode tables to provide extensibility for mime module +use coroutines instead of fancy filters unify filter and send/receive callback. new sink/source/pump idea. get rid of aux_optlstring wrap sink and sources with a function that performs the replacement diff --git a/etc/b64.lua b/etc/b64.lua index de83578..ea157c4 100644 --- a/etc/b64.lua +++ b/etc/b64.lua @@ -1,13 +1,12 @@ +local source = ltn12.source.file(io.stdin) +local sink = ltn12.sink.file(io.stdout) local convert if arg and arg[1] == '-d' then - convert = socket.mime.decode("base64") + convert = mime.decode("base64") else - local base64 = socket.mime.encode("base64") - local wrap = socket.mime.wrap() - convert = socket.mime.chain(base64, wrap) -end -while 1 do - local chunk = io.read(4096) - io.write(convert(chunk)) - if not chunk then break end + local base64 = mime.encode("base64") + local wrap = mime.wrap() + convert = ltn12.filter.chain(base64, wrap) end +source = ltn12.source.chain(source, convert) +ltn12.pump(source, sink) diff --git a/etc/get.lua b/etc/get.lua index d6760b8..0306b54 100644 --- a/etc/get.lua +++ b/etc/get.lua @@ -80,39 +80,31 @@ function stats(size) end end --- downloads a file using the ftp protocol -function getbyftp(url, file) - local save = socket.callback.receive.file(file or io.stdout) - if file then - save = socket.callback.receive.chain(stats(gethttpsize(url)), save) - end - local err = socket.ftp.get_cb { - url = url, - content_cb = save, - type = "i" - } - if err then print(err) end +-- determines the size of a http file +function gethttpsize(url) + local respt = socket.http.request {method = "HEAD", url = url} + if respt.code == 200 then + return tonumber(respt.headers["content-length"]) + end end -- downloads a file using the http protocol function getbyhttp(url, file) - local save = socket.callback.receive.file(file or io.stdout) - if file then - save = socket.callback.receive.chain(stats(gethttpsize(url)), save) - end - local response = socket.http.request_cb({url = url}, {body_cb = save}) - if response.code ~= 200 then print(response.status or response.error) end + local save = ltn12.sink.file(file or io.stdout) + -- only print feedback if output is not stdout + if file then save = ltn12.sink.chain(stats(gethttpsize(url)), save) end + local respt = socket.http.request_cb({url = url, sink = save}) + if respt.code ~= 200 then print(respt.status or respt.error) end end --- determines the size of a http file -function gethttpsize(url) - local response = socket.http.request { - method = "HEAD", - url = url - } - if response.code == 200 then - return tonumber(response.headers["content-length"]) - end +-- downloads a file using the ftp protocol +function getbyftp(url, file) + local save = ltn12.sink.file(file or io.stdout) + -- only print feedback if output is not stdout + -- and we don't know how big the file is + if file then save = ltn12.sink.chain(stats(), save) end + local ret, err = socket.ftp.get_cb {url = url, sink = save, type = "i"} + if err then print(err) end end -- determines the scheme @@ -130,7 +122,6 @@ function get(url, name) if scheme == "ftp" then getbyftp(url, fout) elseif scheme == "http" then getbyhttp(url, fout) else print("unknown scheme" .. scheme) end - if name then fout:close() end end -- main program diff --git a/src/auxiliar.c b/src/auxiliar.c index 812d7fc..fe21d08 100644 --- a/src/auxiliar.c +++ b/src/auxiliar.c @@ -158,14 +158,3 @@ void *aux_getclassudata(lua_State *L, const char *classname, int objidx) { return luaL_checkudata(L, objidx, classname); } - -/*-------------------------------------------------------------------------*\ -* Accept "false" as nil -\*-------------------------------------------------------------------------*/ -const char *aux_optlstring(lua_State *L, int n, const char *v, size_t *l) -{ - if (lua_isnil(L, n) || (lua_isboolean(L, n) && !lua_toboolean(L, n))) { - *l = 0; - return NULL; - } else return luaL_optlstring(L, n, v, l); -} diff --git a/src/auxiliar.h b/src/auxiliar.h index ac62ecd..bc45182 100644 --- a/src/auxiliar.h +++ b/src/auxiliar.h @@ -49,6 +49,5 @@ void *aux_checkgroup(lua_State *L, const char *groupname, int objidx); void *aux_getclassudata(lua_State *L, const char *groupname, int objidx); void *aux_getgroupudata(lua_State *L, const char *groupname, int objidx); int aux_checkboolean(lua_State *L, int objidx); -const char *aux_optlstring(lua_State *L, int n, const char *v, size_t *l); #endif /* AUX_H */ diff --git a/src/ftp.lua b/src/ftp.lua index e596416..18dab6d 100644 --- a/src/ftp.lua +++ b/src/ftp.lua @@ -5,62 +5,29 @@ -- Conforming to: RFC 959, LTN7 -- RCS ID: $Id$ ----------------------------------------------------------------------------- - -local Public, Private = {}, {} -local socket = _G[LUASOCKET_LIBNAME] -- get LuaSocket namespace -socket.ftp = Public -- create ftp sub namespace +-- make sure LuaSocket is loaded +if not LUASOCKET_LIBNAME then error('module requires LuaSocket') end +-- get LuaSocket namespace +local socket = _G[LUASOCKET_LIBNAME] +if not socket then error('module requires LuaSocket') end +-- create namespace inside LuaSocket namespace +socket.ftp = socket.ftp or {} +-- make all module globals fall into namespace +setmetatable(socket.ftp, { __index = _G }) +setfenv(1, socket.ftp) ----------------------------------------------------------------------------- -- Program constants ----------------------------------------------------------------------------- -- timeout in seconds before the program gives up on a connection -Public.TIMEOUT = 60 +TIMEOUT = 60 -- default port for ftp service -Public.PORT = 21 +PORT = 21 -- this is the default anonymous password. used when no password is -- provided in url. should be changed to your e-mail. -Public.EMAIL = "anonymous@anonymous.org" +EMAIL = "anonymous@anonymous.org" -- block size used in transfers -Public.BLOCKSIZE = 8192 - ------------------------------------------------------------------------------ --- Tries to get a pattern from the server and closes socket on error --- sock: socket connected to the server --- pattern: pattern to receive --- Returns --- received pattern on success --- nil followed by error message on error ------------------------------------------------------------------------------ -function Private.try_receive(sock, pattern) - local data, err = sock:receive(pattern) - if not data then sock:close() end - return data, err -end - ------------------------------------------------------------------------------ --- Tries to send data to the server and closes socket on error --- sock: socket connected to the server --- data: data to send --- Returns --- err: error message if any, nil if successfull ------------------------------------------------------------------------------ -function Private.try_send(sock, data) - local sent, err = sock:send(data) - if not sent then sock:close() end - return err -end - ------------------------------------------------------------------------------ --- Tries to send DOS mode lines. Closes socket on error. --- Input --- sock: server socket --- line: string to be sent --- Returns --- err: message in case of error, nil if successfull ------------------------------------------------------------------------------ -function Private.try_sendline(sock, line) - return Private.try_send(sock, line .. "\r\n") -end +BLOCKSIZE = 2048 ----------------------------------------------------------------------------- -- Gets ip and port for data connection from PASV answer @@ -70,7 +37,7 @@ end -- ip: string containing ip for data connection -- port: port for data connection ----------------------------------------------------------------------------- -function Private.get_pasv(pasv) +local function get_pasv(pasv) local a, b, c, d, p1, p2, _ local ip, port _,_, a, b, c, d, p1, p2 = @@ -81,88 +48,6 @@ function Private.get_pasv(pasv) return ip, port end ------------------------------------------------------------------------------ --- Sends a FTP command through socket --- Input --- control: control connection socket --- cmd: command --- arg: command argument if any --- Returns --- error message in case of error, nil otherwise ------------------------------------------------------------------------------ -function Private.send_command(control, cmd, arg) - local line - if arg then line = cmd .. " " .. arg - else line = cmd end - return Private.try_sendline(control, line) -end - ------------------------------------------------------------------------------ --- Gets FTP command answer, unfolding if neccessary --- Input --- control: control connection socket --- Returns --- answer: whole server reply, nil if error --- code: answer status code or error message ------------------------------------------------------------------------------ -function Private.get_answer(control) - local code, lastcode, sep, _ - local line, err = Private.try_receive(control) - local answer = line - if err then return nil, err end - _,_, code, sep = string.find(line, "^(%d%d%d)(.)") - if not code or not sep then return nil, answer end - if sep == "-" then -- answer is multiline - repeat - line, err = Private.try_receive(control) - if err then return nil, err end - _,_, lastcode, sep = string.find(line, "^(%d%d%d)(.)") - answer = answer .. "\n" .. line - until code == lastcode and sep == " " -- answer ends with same code - end - return answer, tonumber(code) -end - ------------------------------------------------------------------------------ --- Checks if a message return is correct. Closes control connection if not. --- Input --- control: control connection socket --- success: table with successfull reply status code --- Returns --- code: reply code or nil in case of error --- answer: server complete answer or system error message ------------------------------------------------------------------------------ -function Private.check_answer(control, success) - local answer, code = Private.get_answer(control) - if not answer then return nil, code end - if type(success) ~= "table" then success = {success} end - for _, s in ipairs(success) do - if code == s then - return code, answer - end - end - control:close() - return nil, answer -end - ------------------------------------------------------------------------------ --- Trys a command on control socked, in case of error, the control connection --- is closed. --- Input --- control: control connection socket --- cmd: command --- arg: command argument or nil if no argument --- success: table with successfull reply status code --- Returns --- code: reply code or nil in case of error --- answer: server complete answer or system error message ------------------------------------------------------------------------------ -function Private.command(control, cmd, arg, success) - local err = Private.send_command(control, cmd, arg) - if err then return nil, err end - return Private.check_answer(control, success) -end - ----------------------------------------------------------------------------- -- Check server greeting -- Input @@ -171,10 +56,10 @@ end -- code: nil if error -- answer: server answer or error message ----------------------------------------------------------------------------- -function Private.greet(control) - local code, answer = Private.check_answer(control, {120, 220}) +local function greet(control) + local code, answer = check_answer(control, {120, 220}) if code == 120 then -- please try again, somewhat busy now... - return Private.check_answer(control, {220}) + return check_answer(control, {220}) end return code, answer end @@ -189,10 +74,10 @@ end -- code: nil if error -- answer: server answer or error message ----------------------------------------------------------------------------- -function Private.login(control, user, password) - local code, answer = Private.command(control, "user", user, {230, 331}) +local function login(control, user, password) + local code, answer = command(control, "user", user, {230, 331}) if code == 331 and password then -- need pass and we have pass - return Private.command(control, "pass", password, {230, 202}) + return command(control, "pass", password, {230, 202}) end return code, answer end @@ -206,9 +91,7 @@ end -- code: nil if error -- answer: server answer or error message ----------------------------------------------------------------------------- -function Private.cwd(control, path) - if path then return Private.command(control, "cwd", path, {250}) - else return 250, nil end +local function cwd(control, path) end ----------------------------------------------------------------------------- @@ -219,18 +102,18 @@ end -- server: server socket bound to local address, nil if error -- answer: error message if any ----------------------------------------------------------------------------- -function Private.port(control) +local function port(control) local code, answer local server, ctl_ip ctl_ip, answer = control:getsockname() server, answer = socket.bind(ctl_ip, 0) - server:settimeout(Public.TIMEOUT) + server:settimeout(TIMEOUT) local ip, p, ph, pl ip, p = server:getsockname() pl = math.mod(p, 256) ph = (p - pl)/256 local arg = string.gsub(string.format("%s,%d,%d", ip, ph, pl), "%.", ",") - code, answer = Private.command(control, "port", arg, {200}) + code, answer = command(control, "port", arg, {200}) if not code then server:close() return nil, answer @@ -245,8 +128,8 @@ end -- code: nil if error -- answer: server answer or error message ----------------------------------------------------------------------------- -function Private.logout(control) - local code, answer = Private.command(control, "quit", nil, {221}) +local function logout(control) + local code, answer = command(control, "quit", nil, {221}) if code then control:close() end return code, answer end @@ -259,10 +142,10 @@ end -- Returns -- nil if successfull, or an error message in case of error ----------------------------------------------------------------------------- -function Private.receive_indirect(data, callback) +local function receive_indirect(data, callback) local chunk, err, res while not err do - chunk, err = Private.try_receive(data, Public.BLOCKSIZE) + chunk, err = try_receive(data, BLOCKSIZE) if err == "closed" then err = "done" end res = callback(chunk, err) if not res then break end @@ -280,16 +163,16 @@ end -- Returns -- err: error message in case of error, nil otherwise ----------------------------------------------------------------------------- -function Private.retrieve(control, server, name, is_directory, content_cb) +local function retrieve(control, server, name, is_directory, content_cb) local code, answer local data -- ask server for file or directory listing accordingly if is_directory then - code, answer = Private.cwd(control, name) + code, answer = cwd(control, name) if not code then return answer end - code, answer = Private.command(control, "nlst", nil, {150, 125}) + code, answer = command(control, "nlst", nil, {150, 125}) else - code, answer = Private.command(control, "retr", name, {150, 125}) + code, answer = command(control, "retr", name, {150, 125}) end if not code then return nil, answer end data, answer = server:accept() @@ -298,43 +181,14 @@ function Private.retrieve(control, server, name, is_directory, content_cb) control:close() return answer end - answer = Private.receive_indirect(data, content_cb) + answer = receive_indirect(data, content_cb) if answer then control:close() return answer end data:close() -- make sure file transfered ok - return Private.check_answer(control, {226, 250}) -end - ------------------------------------------------------------------------------ --- Sends data comming from a callback --- Input --- data: data connection --- send_cb: callback to produce file contents --- chunk, size: first callback return values --- Returns --- nil if successfull, or an error message in case of error ------------------------------------------------------------------------------ -function Private.send_indirect(data, send_cb, chunk, size) - local total, sent, err - total = 0 - while 1 do - if type(chunk) ~= "string" or type(size) ~= "number" then - data:close() - if not chunk and type(size) == "string" then return size - else return "invalid callback return" end - end - sent, err = data:send(chunk) - if err then - data:close() - return err - end - total = total + sent - if sent >= size then break end - chunk, size = send_cb() - end + return check_answer(control, {226, 250}) end ----------------------------------------------------------------------------- @@ -348,9 +202,9 @@ end -- code: return code, nil if error -- answer: server answer or error message ----------------------------------------------------------------------------- -function Private.store(control, server, file, send_cb) +local function store(control, server, file, send_cb) local data, err - local code, answer = Private.command(control, "stor", file, {150, 125}) + local code, answer = command(control, "stor", file, {150, 125}) if not code then control:close() return nil, answer @@ -363,7 +217,7 @@ function Private.store(control, server, file, send_cb) return nil, answer end -- send whole file - err = Private.send_indirect(data, send_cb, send_cb()) + err = send_indirect(data, send_cb, send_cb()) if err then control:close() return nil, err @@ -371,7 +225,7 @@ function Private.store(control, server, file, send_cb) -- close connection to inform that file transmission is complete data:close() -- check if file was received correctly - return Private.check_answer(control, {226, 250}) + return check_answer(control, {226, 250}) end ----------------------------------------------------------------------------- @@ -382,11 +236,11 @@ end -- Returns -- err: error message if any ----------------------------------------------------------------------------- -function Private.change_type(control, params) +local function change_type(control, params) local type, _ _, _, type = string.find(params or "", "type=(.)") if type == "a" or type == "i" then - local code, err = Private.command(control, "type", type, {200}) + local code, err = command(control, "type", type, {200}) if not code then return err end end end @@ -399,45 +253,42 @@ end -- control: control connection with server, or nil if error -- err: error message if any ----------------------------------------------------------------------------- -function Private.open(parsed) - -- start control connection - local control, err = socket.connect(parsed.host, parsed.port) +local function open(parsed) + local control, err = socket.tp.connect(parsed.host, parsed.port) if not control then return nil, err end - -- make sure we don't block forever - control:settimeout(Public.TIMEOUT) - -- check greeting - local code, answer = Private.greet(control) - if not code then return nil, answer end - -- try to log in - code, err = Private.login(control, parsed.user, parsed.password) - if not code then return nil, err - else return control end -end - ------------------------------------------------------------------------------ --- Closes the connection with the server --- Input --- control: control connection with server ------------------------------------------------------------------------------ -function Private.close(control) - -- disconnect - Private.logout(control) -end - ------------------------------------------------------------------------------ --- Changes to the directory pointed to by URL --- Input --- control: control connection with server --- segment: parsed URL path segments --- Returns --- err: error message if any ------------------------------------------------------------------------------ -function Private.change_dir(control, segment) - local n = table.getn(segment) - for i = 1, n-1 do - local code, answer = Private.cwd(control, segment[i]) - if not code then return answer end + local code, reply + -- greet + code, reply = control:check({120, 220}) + if code == 120 then -- busy, try again + code, reply = control:check(220) + end + -- authenticate + code, reply = control:command("user", user) + code, reply = control:check({230, 331}) + if code == 331 and password then -- need pass and we have pass + control:command("pass", password) + code, reply = control:check({230, 202}) + end + -- change directory + local segment = parse_path(parsed) + for i, v in ipairs(segment) do + code, reply = control:command("cwd") + code, reply = control:check(250) end + -- change type + local type = string.sub(params or "", 7, 7) + if type == "a" or type == "i" then + code, reply = control:command("type", type) + code, reply = control:check(200) + end +end + + return change_dir(control, segment) or + change_type(control, parsed.params) or + download(control, request, segment) or + close(control) +end + end ----------------------------------------------------------------------------- @@ -450,7 +301,7 @@ end -- Returns -- err: error message if any ----------------------------------------------------------------------------- -function Private.upload(control, request, segment) +local function upload(control, request, segment) local code, name, content_cb -- get remote file name name = segment[table.getn(segment)] @@ -460,10 +311,10 @@ function Private.upload(control, request, segment) end content_cb = request.content_cb -- setup passive connection - local server, answer = Private.port(control) + local server, answer = port(control) if not server then return answer end -- ask server to receive file - code, answer = Private.store(control, server, name, content_cb) + code, answer = store(control, server, name, content_cb) if not code then return answer end end @@ -477,7 +328,7 @@ end -- Returns -- err: error message if any ----------------------------------------------------------------------------- -function Private.download(control, request, segment) +local function download(control, request, segment) local code, name, is_directory, content_cb is_directory = segment.is_directory content_cb = request.content_cb @@ -488,10 +339,10 @@ function Private.download(control, request, segment) return "Invalid file path" end -- setup passive connection - local server, answer = Private.port(control) + local server, answer = port(control) if not server then return answer end -- ask server to send file or directory listing - code, answer = Private.retrieve(control, server, name, + code, answer = retrieve(control, server, name, is_directory, content_cb) if not code then return answer end end @@ -507,13 +358,12 @@ end -- Returns -- parsed: a table with parsed components ----------------------------------------------------------------------------- -function Private.parse_url(request) +local function parse_url(request) local parsed = socket.url.parse(request.url, { - host = "", user = "anonymous", port = 21, path = "/", - password = Public.EMAIL, + password = EMAIL, scheme = "ftp" }) -- explicit login information overrides that given by URL @@ -531,7 +381,7 @@ end -- Returns -- dirs: a table with parsed directory components ----------------------------------------------------------------------------- -function Private.parse_path(parsed_url) +local function parse_path(parsed_url) local segment = socket.url.parse_path(parsed_url.path) segment.is_directory = segment.is_directory or (parsed_url.params == "type=d") @@ -549,7 +399,7 @@ end -- Returns -- request: request table ----------------------------------------------------------------------------- -function Private.build_request(data) +local function build_request(data) local request = {} if type(data) == "table" then for i, v in data do request[i] = v end else request.url = data end @@ -568,18 +418,18 @@ end -- Returns -- err: error message if any ----------------------------------------------------------------------------- -function Public.get_cb(request) - local parsed = Private.parse_url(request) +function get_cb(request) + local parsed = parse_url(request) if parsed.scheme ~= "ftp" then return string.format("unknown scheme '%s'", parsed.scheme) end - local control, err = Private.open(parsed) + local control, err = open(parsed) if not control then return err end - local segment = Private.parse_path(parsed) - return Private.change_dir(control, segment) or - Private.change_type(control, parsed.params) or - Private.download(control, request, segment) or - Private.close(control) + local segment = parse_path(parsed) + return change_dir(control, segment) or + change_type(control, parsed.params) or + download(control, request, segment) or + close(control) end ----------------------------------------------------------------------------- @@ -594,18 +444,18 @@ end -- Returns -- err: error message if any ----------------------------------------------------------------------------- -function Public.put_cb(request) - local parsed = Private.parse_url(request) +function put_cb(request) + local parsed = parse_url(request) if parsed.scheme ~= "ftp" then return string.format("unknown scheme '%s'", parsed.scheme) end - local control, err = Private.open(parsed) + local control, err = open(parsed) if not control then return err end - local segment = Private.parse_path(parsed) - err = Private.change_dir(control, segment) or - Private.change_type(control, parsed.params) or - Private.upload(control, request, segment) or - Private.close(control) + local segment = parse_path(parsed) + err = change_dir(control, segment) or + change_type(control, parsed.params) or + upload(control, request, segment) or + close(control) if err then return nil, err else return 1 end end @@ -623,11 +473,11 @@ end -- Returns -- err: error message if any ----------------------------------------------------------------------------- -function Public.put(url_or_request, content) - local request = Private.build_request(url_or_request) +function put(url_or_request, content) + local request = build_request(url_or_request) request.content = request.content or content request.content_cb = socket.callback.send_string(request.content) - return Public.put_cb(request) + return put_cb(request) end ----------------------------------------------------------------------------- @@ -642,12 +492,12 @@ end -- data: file contents as a string -- err: error message in case of error, nil otherwise ----------------------------------------------------------------------------- -function Public.get(url_or_request) +function get(url_or_request) local concat = socket.concat.create() - local request = Private.build_request(url_or_request) + local request = build_request(url_or_request) request.content_cb = socket.callback.receive_concat(concat) - local err = Public.get_cb(request) + local err = get_cb(request) return concat:getresult(), err end -return ftp +return socket.ftp diff --git a/src/http.lua b/src/http.lua index 74c29ba..629bf65 100644 --- a/src/http.lua +++ b/src/http.lua @@ -10,12 +10,11 @@ if not LUASOCKET_LIBNAME then error('module requires LuaSocket') end -- get LuaSocket namespace local socket = _G[LUASOCKET_LIBNAME] if not socket then error('module requires LuaSocket') end --- create smtp namespace inside LuaSocket namespace -local http = socket.http or {} -socket.http = http --- make all module globals fall into smtp namespace -setmetatable(http, { __index = _G }) -setfenv(1, http) +-- create namespace inside LuaSocket namespace +socket.http = socket.http or {} +-- make all module globals fall into namespace +setmetatable(socket.http, { __index = _G }) +setfenv(1, socket.http) ----------------------------------------------------------------------------- -- Program constants @@ -27,7 +26,18 @@ PORT = 80 -- user agent field sent in request USERAGENT = socket.version -- block size used in transfers -BLOCKSIZE = 8192 +BLOCKSIZE = 2048 + +----------------------------------------------------------------------------- +-- Function return value selectors +----------------------------------------------------------------------------- +local function second(a, b) + return b +end + +local function third(a, b, c) + return c +end ----------------------------------------------------------------------------- -- Tries to get a pattern from the server and closes socket on error @@ -47,7 +57,7 @@ end ----------------------------------------------------------------------------- -- Tries to send data to the server and closes socket on error -- sock: socket connected to the server --- data: data to send +-- ...: data to send -- Returns -- err: error message if any, nil if successfull ----------------------------------------------------------------------------- @@ -68,11 +78,9 @@ end -- err: error message if any ----------------------------------------------------------------------------- local function receive_status(sock) - local line, err - line, err = try_receiving(sock) + local line, err = try_receiving(sock) if not err then - local code, _ - _, _, code = string.find(line, "HTTP/%d*%.%d* (%d%d%d)") + local code = third(string.find(line, "HTTP/%d*%.%d* (%d%d%d)")) return tonumber(code), line else return nil, nil, err end end @@ -121,7 +129,7 @@ local function receive_headers(sock, headers) end ----------------------------------------------------------------------------- --- Aborts a receive callback +-- Aborts a sink with an error message -- Input -- cb: callback function -- err: error message to pass to callback @@ -129,8 +137,8 @@ end -- callback return or if nil err ----------------------------------------------------------------------------- local function abort(cb, err) - local go, err_or_f = cb(nil, err) - return err_or_f or err + local go, cb_err = cb(nil, err) + return cb_err or err end ----------------------------------------------------------------------------- @@ -138,41 +146,36 @@ end -- Input -- sock: socket connected to the server -- headers: header set in which to include trailer headers --- receive_cb: function to receive chunks +-- sink: response message body sink -- Returns -- nil if successfull or an error message in case of error ----------------------------------------------------------------------------- -local function receive_body_bychunks(sock, headers, receive_cb) - local chunk, size, line, err, go, err_or_f, _ +local function receive_body_bychunks(sock, headers, sink) + local chunk, size, line, err, go while 1 do -- get chunk size, skip extention line, err = try_receiving(sock) - if err then return abort(receive_cb, err) end + if err then return abort(sink, err) end size = tonumber(string.gsub(line, ";.*", ""), 16) - if not size then return abort(receive_cb, "invalid chunk size") end + if not size then return abort(sink, "invalid chunk size") end -- was it the last chunk? if size <= 0 then break end -- get chunk chunk, err = try_receiving(sock, size) - if err then return abort(receive_cb, err) end + if err then return abort(sink, err) end -- pass chunk to callback - go, err_or_f = receive_cb(chunk) - -- see if callback needs to be replaced - receive_cb = err_or_f or receive_cb + go, err = sink(chunk) -- see if callback aborted - if not go then return err_or_f or "aborted by callback" end + if not go then return err or "aborted by callback" end -- skip CRLF on end of chunk - _, err = try_receiving(sock) - if err then return abort(receive_cb, err) end + err = second(try_receiving(sock)) + if err then return abort(sink, err) end end - -- the server should not send trailer headers because we didn't send a - -- header informing it we know how to deal with them. we do not risk - -- being caught unprepaired. - _, err = receive_headers(sock, headers) - if err then return abort(receive_cb, err) end + -- servers shouldn't send trailer headers, but who trusts them? + err = second(receive_headers(sock, headers)) + if err then return abort(sink, err) end -- let callback know we are done - _, err_or_f = receive_cb("") - return err_or_f + return second(sink(nil)) end ----------------------------------------------------------------------------- @@ -180,94 +183,84 @@ end -- Input -- sock: socket connected to the server -- length: message body length --- receive_cb: function to receive chunks +-- sink: response message body sink -- Returns -- nil if successfull or an error message in case of error ----------------------------------------------------------------------------- -local function receive_body_bylength(sock, length, receive_cb) +local function receive_body_bylength(sock, length, sink) while length > 0 do local size = math.min(BLOCKSIZE, length) local chunk, err = sock:receive(size) - local go, err_or_f = receive_cb(chunk) + local go, cb_err = sink(chunk) length = length - string.len(chunk) -- see if callback aborted - if not go then return err_or_f or "aborted by callback" end - -- see if callback needs to be replaced - receive_cb = err_or_f or receive_cb + if not go then return cb_err or "aborted by callback" end -- see if there was an error - if err and length > 0 then return abort(receive_cb, err) end + if err and length > 0 then return abort(sink, err) end end - local _, err_or_f = receive_cb("") - return err_or_f + return second(sink(nil)) end ----------------------------------------------------------------------------- --- Receives a message body by content-length +-- Receives a message body until the conection is closed -- Input -- sock: socket connected to the server --- receive_cb: function to receive chunks +-- sink: response message body sink -- Returns -- nil if successfull or an error message in case of error ----------------------------------------------------------------------------- -local function receive_body_untilclosed(sock, receive_cb) +local function receive_body_untilclosed(sock, sink) while 1 do local chunk, err = sock:receive(BLOCKSIZE) - local go, err_or_f = receive_cb(chunk) + local go, cb_err = sink(chunk) -- see if callback aborted - if not go then return err_or_f or "aborted by callback" end - -- see if callback needs to be replaced - receive_cb = err_or_f or receive_cb + if not go then return cb_err or "aborted by callback" end -- see if we are done - if err == "closed" then - if chunk ~= "" then - go, err_or_f = receive_cb("") - return err_or_f - end - end + if err == "closed" then return chunk and second(sink(nil)) end -- see if there was an error - if err then return abort(receive_cb, err) end + if err then return abort(sink, err) end end end ----------------------------------------------------------------------------- --- Receives HTTP response body +-- Receives the HTTP response body -- Input -- sock: socket connected to the server -- headers: response header fields --- receive_cb: function to receive chunks +-- sink: response message body sink -- Returns -- nil if successfull or an error message in case of error ----------------------------------------------------------------------------- -local function receive_body(sock, headers, receive_cb) +local function receive_body(sock, headers, sink) + -- make sure sink is not fancy + sink = ltn12.sink.simplify(sink) local te = headers["transfer-encoding"] if te and te ~= "identity" then -- get by chunked transfer-coding of message body - return receive_body_bychunks(sock, headers, receive_cb) + return receive_body_bychunks(sock, headers, sink) elseif tonumber(headers["content-length"]) then -- get by content-length local length = tonumber(headers["content-length"]) - return receive_body_bylength(sock, length, receive_cb) + return receive_body_bylength(sock, length, sink) else -- get it all until connection closes - return receive_body_untilclosed(sock, receive_cb) + return receive_body_untilclosed(sock, sink) end end ----------------------------------------------------------------------------- --- Sends data comming from a callback +-- Sends the HTTP request message body in chunks -- Input -- data: data connection --- send_cb: callback to produce file contents +-- source: request message body source -- Returns -- nil if successfull, or an error message in case of error ----------------------------------------------------------------------------- -local function send_body_bychunks(data, send_cb) +local function send_body_bychunks(data, source) while 1 do - local chunk, err_or_f = send_cb() + local chunk, cb_err = source() -- check if callback aborted - if not chunk then return err_or_f or "aborted by callback" end - -- check if callback should be replaced - send_cb = err_or_f or send_cb + if not chunk then return cb_err or "aborted by callback" end -- if we are done, send last-chunk if chunk == "" then return try_sending(data, "0\r\n\r\n") end -- else send middle chunk @@ -281,22 +274,18 @@ local function send_body_bychunks(data, send_cb) end ----------------------------------------------------------------------------- --- Sends data comming from a callback +-- Sends the HTTP request message body -- Input -- data: data connection --- send_cb: callback to produce body contents +-- source: request message body source -- Returns -- nil if successfull, or an error message in case of error ----------------------------------------------------------------------------- -local function send_body_bylength(data, send_cb) +local function send_body(data, source) while 1 do - local chunk, err_or_f = send_cb() - -- check if callback aborted - if not chunk then return err_or_f or "aborted by callback" end - -- check if callback should be replaced - send_cb = err_or_f or send_cb + local chunk, cb_err = source() -- check if callback is done - if chunk == "" then return end + if not chunk then return cb_err end -- send data local err = try_sending(data, chunk) if err then return err end @@ -304,10 +293,10 @@ local function send_body_bylength(data, send_cb) end ----------------------------------------------------------------------------- --- Sends mime headers +-- Sends request headers -- Input -- sock: server socket --- headers: table with mime headers to be sent +-- headers: table with headers to be sent -- Returns -- err: error message if any ----------------------------------------------------------------------------- @@ -330,27 +319,29 @@ end -- method: request method to be used -- uri: request uri -- headers: request headers to be sent --- body_cb: callback to send request message body +-- source: request message body source -- Returns -- err: nil in case of success, error message otherwise ----------------------------------------------------------------------------- -local function send_request(sock, method, uri, headers, body_cb) +local function send_request(sock, method, uri, headers, source) local chunk, size, done, err -- send request line err = try_sending(sock, method .. " " .. uri .. " HTTP/1.1\r\n") if err then return err end - if body_cb and not headers["content-length"] then + if source and not headers["content-length"] then headers["transfer-encoding"] = "chunked" end -- send request headers err = send_headers(sock, headers) if err then return err end -- send request message body, if any - if body_cb then - if not headers["content-length"] then - return send_body_bychunks(sock, body_cb) + if source then + -- make sure source is not fancy + source = ltn12.source.simplify(source) + if headers["content-length"] then + return send_body(sock, source) else - return send_body_bylength(sock, body_cb) + return send_body_bychunks(sock, source) end end end @@ -415,23 +406,23 @@ end -- Input -- reqt: a table with the original request information -- parsed: parsed request URL --- respt: a table with the server response information -- Returns -- respt: result of target authorization ----------------------------------------------------------------------------- -local function authorize(reqt, parsed, respt) +local function authorize(reqt, parsed) reqt.headers["authorization"] = "Basic " .. - (socket.mime.b64(parsed.user .. ":" .. parsed.password)) + (mime.b64(parsed.user .. ":" .. parsed.password)) local autht = { nredirects = reqt.nredirects, method = reqt.method, url = reqt.url, - body_cb = reqt.body_cb, + source = reqt.source, + sink = reqt.sink, headers = reqt.headers, timeout = reqt.timeout, proxy = reqt.proxy, } - return request_cb(autht, respt) + return request_cb(autht) end ----------------------------------------------------------------------------- @@ -443,8 +434,8 @@ end -- 1 if we should redirect, nil otherwise ----------------------------------------------------------------------------- local function should_redirect(reqt, respt) - return (reqt.redirect ~= false) and - (respt.code == 301 or respt.code == 302) and + return (reqt.redirect ~= false) and + (respt.code == 301 or respt.code == 302) and (reqt.method == "GET" or reqt.method == "HEAD") and not (reqt.nredirects and reqt.nredirects >= 5) end @@ -453,8 +444,7 @@ end -- Returns the result of a request following a server redirect message. -- Input -- reqt: a table with the original request information --- respt: a table with the following fields: --- body_cb: response method body receive-callback +-- respt: response table of previous attempt -- Returns -- respt: result of target redirection ----------------------------------------------------------------------------- @@ -467,12 +457,13 @@ local function redirect(reqt, respt) -- the RFC says the redirect URL has to be absolute, but some -- servers do not respect that url = socket.url.absolute(reqt.url, respt.headers["location"]), - body_cb = reqt.body_cb, + source = reqt.source, + sink = reqt.sink, headers = reqt.headers, timeout = reqt.timeout, proxy = reqt.proxy } - respt = request_cb(redirt, respt) + respt = request_cb(redirt) -- we pass the location header as a clue we tried to redirect if respt.headers then respt.headers.location = redirt.url end return respt @@ -562,10 +553,9 @@ end -- url: target uniform resource locator -- user, password: authentication information -- headers: request headers to send, or nil if none --- body_cb: request message body send-callback, or nil if none +-- source: request message body source, or nil if none +-- sink: response message body sink -- redirect: should we refrain from following a server redirect message? --- respt: a table with the following fields: --- body_cb: response method body receive-callback -- Returns -- respt: a table with the following fields: -- headers: response header fields received, or nil if failed @@ -573,7 +563,7 @@ end -- code: server status code, or nil if failed -- error: error message, or nil if successfull ----------------------------------------------------------------------------- -function request_cb(reqt, respt) +function request_cb(reqt) local sock, ret local parsed = socket.url.parse(reqt.url, { host = "", @@ -581,6 +571,7 @@ function request_cb(reqt, respt) path ="/", scheme = "http" }) + local respt = {} if parsed.scheme ~= "http" then respt.error = string.format("unknown scheme '%s'", parsed.scheme) return respt @@ -597,7 +588,7 @@ function request_cb(reqt, respt) if not sock then return respt end -- send request message respt.error = send_request(sock, reqt.method, - request_uri(reqt, parsed), reqt.headers, reqt.body_cb) + request_uri(reqt, parsed), reqt.headers, reqt.source) if respt.error then sock:close() return respt @@ -619,18 +610,18 @@ function request_cb(reqt, respt) -- decide what to do based on request and response parameters if should_redirect(reqt, respt) then -- drop the body - receive_body(sock, respt.headers, function (c, e) return 1 end) + receive_body(sock, respt.headers, ltn12.sink.null()) -- we are done with this connection sock:close() return redirect(reqt, respt) elseif should_authorize(reqt, parsed, respt) then -- drop the body - receive_body(sock, respt.headers, function (c, e) return 1 end) + receive_body(sock, respt.headers, ltn12.sink.null()) -- we are done with this connection sock:close() return authorize(reqt, parsed, respt) elseif should_receive_body(reqt, respt) then - respt.error = receive_body(sock, respt.headers, respt.body_cb) + respt.error = receive_body(sock, respt.headers, reqt.sink) if respt.error then return respt end sock:close() return respt @@ -658,13 +649,11 @@ end -- error: error message if any ----------------------------------------------------------------------------- function request(reqt) - local respt = {} - reqt.body_cb = socket.callback.send.string(reqt.body) - local concat = socket.concat.create() - respt.body_cb = socket.callback.receive.concat(concat) - respt = request_cb(reqt, respt) - respt.body = concat:getresult() - respt.body_cb = nil + reqt.source = reqt.body and ltn12.source.string(reqt.body) + local t = {} + reqt.sink = ltn12.sink.table(t) + local respt = request_cb(reqt) + if table.getn(t) > 0 then respt.body = table.concat(t) end return respt end @@ -713,4 +702,4 @@ function post(url_or_request, body) return respt.body, respt.headers, respt.code, respt.error end -return http +return socket.http diff --git a/src/ltn12.lua b/src/ltn12.lua index 548588a..de7103d 100644 --- a/src/ltn12.lua +++ b/src/ltn12.lua @@ -1,6 +1,6 @@ --- create code namespace inside LuaSocket namespace +-- create module namespace ltn12 = ltn12 or {} --- make all module globals fall into mime namespace +-- make all globals fall into ltn12 namespace setmetatable(ltn12, { __index = _G }) setfenv(1, ltn12) @@ -12,6 +12,14 @@ sink = {} -- 2048 seems to be better in windows... BLOCKSIZE = 2048 +local function second(a, b) + return b +end + +local function skip(a, b, c) + return b, c +end + -- returns a high level filter that cycles a cycles a low-level filter function filter.cycle(low, ctx, extra) return function(chunk) @@ -24,9 +32,7 @@ end -- chains two filters together local function chain2(f1, f2) return function(chunk) - local ret = f2(f1(chunk)) - if chunk then return ret - else return ret .. f2() end + return f2(f1(chunk)) end end @@ -83,7 +89,6 @@ end -- creates rewindable source function source.rewind(src) local t = {} - src = source.simplify(src) return function(chunk) if not chunk then chunk = table.remove(t) @@ -97,13 +102,38 @@ end -- chains a source with a filter function source.chain(src, f) - src = source.simplify(src) - local chain = function() - local chunk, err = src() - if not chunk then return f(nil), source.empty(err) - else return f(chunk) end + local co = coroutine.create(function() + while true do + local chunk, err = src() + local filtered = f(chunk) + local done = chunk and "" + while true do + coroutine.yield(filtered) + if filtered == done then break end + filtered = f(done) + end + if not chunk then return nil, err end + end + end) + return function() + return skip(coroutine.resume(co)) end - return source.simplify(chain) +end + +-- creates a source that produces contents of several files one after the +-- other, as if they were concatenated +function source.cat(...) + local co = coroutine.create(function() + local i = 1 + while i <= table.getn(arg) do + local chunk = arg[i]:read(2048) + if chunk then coroutine.yield(chunk) + else i = i + 1 end + end + end) + return source.simplify(function() + return second(coroutine.resume(co)) + end) end -- creates a sink that stores into a table @@ -150,22 +180,25 @@ end -- chains a sink with a filter function sink.chain(f, snk) - snk = sink.simplify(snk) return function(chunk, err) - local r, e = snk(f(chunk)) - if not r then return nil, e end - if not chunk then return snk(nil, err) end - return 1 + local filtered = f(chunk) + local done = chunk and "" + while true do + local ret, snkerr = snk(filtered, err) + if not ret then return nil, snkerr end + if filtered == done then return 1 end + filtered = f(done) + end end end -- pumps all data from a source to a sink function pump(src, snk) - snk = sink.simplify(snk) - for chunk, src_err in source.simplify(src) do + while true do + local chunk, src_err = src() local ret, snk_err = snk(chunk, src_err) - if not chunk or not ret then - return not src_err and not snk_err, src_err or snk_err + if not chunk or not ret then + return not src_err and not snk_err, src_err or snk_err end end end diff --git a/src/luasocket.c b/src/luasocket.c index 47696cb..5b19696 100644 --- a/src/luasocket.c +++ b/src/luasocket.c @@ -74,22 +74,16 @@ static int mod_open(lua_State *L, const luaL_reg *mod) #ifdef LUASOCKET_COMPILED #include "ltn12.lch" #include "auxiliar.lch" -#include "concat.lch" #include "url.lch" -#include "callback.lch" #include "mime.lch" #include "smtp.lch" -#include "ftp.lch" #include "http.lch" #else lua_dofile(L, "ltn12.lua"); lua_dofile(L, "auxiliar.lua"); - lua_dofile(L, "concat.lua"); lua_dofile(L, "url.lua"); - lua_dofile(L, "callback.lua"); lua_dofile(L, "mime.lua"); lua_dofile(L, "smtp.lua"); - lua_dofile(L, "ftp.lua"); lua_dofile(L, "http.lua"); #endif return 0; diff --git a/src/mime.c b/src/mime.c index 1a8bff4..77f3ae1 100644 --- a/src/mime.c +++ b/src/mime.c @@ -20,8 +20,8 @@ #define SP 0x20 typedef unsigned char UC; -static const char CRLF[2] = {CR, LF}; -static const char EQCRLF[3] = {'=', CR, LF}; +static const char CRLF[] = {CR, LF, 0}; +static const char EQCRLF[] = {'=', CR, LF, 0}; /*=========================================================================*\ * Internal function prototypes. @@ -95,7 +95,7 @@ int mime_open(lua_State *L) * Global Lua functions \*=========================================================================*/ /*-------------------------------------------------------------------------*\ -* Incrementaly breaks a string into lines +* Incrementaly breaks a string into lines. The string can have CRLF breaks. * A, n = wrp(l, B, length) * A is a copy of B, broken into lines of at most 'length' bytes. * 'l' is how many bytes are left for the first line of B. @@ -109,6 +109,15 @@ static int mime_global_wrp(lua_State *L) const UC *last = input + size; int length = (int) luaL_optnumber(L, 3, 76); luaL_Buffer buffer; + /* end of input black-hole */ + if (!input) { + /* if last line has not been terminated, add a line break */ + if (left < length) lua_pushstring(L, CRLF); + /* otherwise, we are done */ + else lua_pushnil(L); + lua_pushnumber(L, length); + return 2; + } luaL_buffinit(L, &buffer); while (input < last) { switch (*input) { @@ -129,11 +138,6 @@ static int mime_global_wrp(lua_State *L) } input++; } - /* if in last chunk and last line wasn't terminated, add a line-break */ - if (!input && left < length) { - luaL_addstring(&buffer, CRLF); - left = length; - } luaL_pushresult(&buffer); lua_pushnumber(L, left); return 2; @@ -200,7 +204,6 @@ static size_t b64pad(const UC *input, size_t size, code[0] = b64base[value]; luaL_addlstring(buffer, (char *) code, 4); break; - case 0: /* fall through */ default: break; } @@ -250,19 +253,31 @@ static int mime_global_b64(lua_State *L) { UC atom[3]; size_t isize = 0, asize = 0; - const UC *input = (UC *) luaL_checklstring(L, 1, &isize); + const UC *input = (UC *) luaL_optlstring(L, 1, NULL, &isize); const UC *last = input + isize; luaL_Buffer buffer; + /* end-of-input blackhole */ + if (!input) { + lua_pushnil(L); + lua_pushnil(L); + return 2; + } + /* process first part of the input */ luaL_buffinit(L, &buffer); while (input < last) asize = b64encode(*input++, atom, asize, &buffer); input = (UC *) luaL_optlstring(L, 2, NULL, &isize); - if (input) { - last = input + isize; - while (input < last) - asize = b64encode(*input++, atom, asize, &buffer); - } else + /* if second part is nil, we are done */ + if (!input) { asize = b64pad(atom, asize, &buffer); + luaL_pushresult(&buffer); + lua_pushnil(L); + return 2; + } + /* otherwise process the second part */ + last = input + isize; + while (input < last) + asize = b64encode(*input++, atom, asize, &buffer); luaL_pushresult(&buffer); lua_pushlstring(L, (char *) atom, asize); return 2; @@ -278,20 +293,30 @@ static int mime_global_unb64(lua_State *L) { UC atom[4]; size_t isize = 0, asize = 0; - const UC *input = (UC *) luaL_checklstring(L, 1, &isize); + const UC *input = (UC *) luaL_optlstring(L, 1, NULL, &isize); const UC *last = input + isize; luaL_Buffer buffer; + /* end-of-input blackhole */ + if (!input) { + lua_pushnil(L); + lua_pushnil(L); + return 2; + } + /* process first part of the input */ luaL_buffinit(L, &buffer); while (input < last) asize = b64decode(*input++, atom, asize, &buffer); input = (UC *) luaL_optlstring(L, 2, NULL, &isize); - if (input) { - last = input + isize; - while (input < last) - asize = b64decode(*input++, atom, asize, &buffer); - } - /* if !input we are done. if atom > 0, the remaning is invalid. we just - * return it undecoded. */ + /* if second is nil, we are done */ + if (!input) { + luaL_pushresult(&buffer); + lua_pushnil(L); + return 2; + } + /* otherwise, process the rest of the input */ + last = input + isize; + while (input < last) + asize = b64decode(*input++, atom, asize, &buffer); luaL_pushresult(&buffer); lua_pushlstring(L, (char *) atom, asize); return 2; @@ -425,16 +450,27 @@ static int mime_global_qp(lua_State *L) const UC *last = input + isize; const char *marker = luaL_optstring(L, 3, CRLF); luaL_Buffer buffer; + /* end-of-input blackhole */ + if (!input) { + lua_pushnil(L); + lua_pushnil(L); + return 2; + } + /* process first part of input */ luaL_buffinit(L, &buffer); while (input < last) asize = qpencode(*input++, atom, asize, marker, &buffer); input = (UC *) luaL_optlstring(L, 2, NULL, &isize); - if (input) { - last = input + isize; - while (input < last) - asize = qpencode(*input++, atom, asize, marker, &buffer); - } else + /* if second part is nil, we are done */ + if (!input) { asize = qppad(atom, asize, &buffer); + luaL_pushresult(&buffer); + lua_pushnil(L); + } + /* otherwise process rest of input */ + last = input + isize; + while (input < last) + asize = qpencode(*input++, atom, asize, marker, &buffer); luaL_pushresult(&buffer); lua_pushlstring(L, (char *) atom, asize); return 2; @@ -487,21 +523,32 @@ static size_t qpdecode(UC c, UC *input, size_t size, \*-------------------------------------------------------------------------*/ static int mime_global_unqp(lua_State *L) { - size_t asize = 0, isize = 0; UC atom[3]; const UC *input = (UC *) luaL_optlstring(L, 1, NULL, &isize); const UC *last = input + isize; luaL_Buffer buffer; + /* end-of-input blackhole */ + if (!input) { + lua_pushnil(L); + lua_pushnil(L); + return 2; + } + /* process first part of input */ luaL_buffinit(L, &buffer); while (input < last) asize = qpdecode(*input++, atom, asize, &buffer); input = (UC *) luaL_optlstring(L, 2, NULL, &isize); - if (input) { - last = input + isize; - while (input < last) - asize = qpdecode(*input++, atom, asize, &buffer); + /* if second part is nil, we are done */ + if (!input) { + luaL_pushresult(&buffer); + lua_pushnil(L); + return 2; } + /* otherwise process rest of input */ + last = input + isize; + while (input < last) + asize = qpdecode(*input++, atom, asize, &buffer); luaL_pushresult(&buffer); lua_pushlstring(L, (char *) atom, asize); return 2; @@ -524,6 +571,14 @@ static int mime_global_qpwrp(lua_State *L) const UC *last = input + size; int length = (int) luaL_optnumber(L, 3, 76); luaL_Buffer buffer; + /* end-of-input blackhole */ + if (!input) { + if (left < length) lua_pushstring(L, EQCRLF); + else lua_pushnil(L); + lua_pushnumber(L, length); + return 2; + } + /* process all input */ luaL_buffinit(L, &buffer); while (input < last) { switch (*input) { @@ -552,11 +607,6 @@ static int mime_global_qpwrp(lua_State *L) } input++; } - /* if in last chunk and last line wasn't terminated, add a soft-break */ - if (!input && left < length) { - luaL_addstring(&buffer, EQCRLF); - left = length; - } luaL_pushresult(&buffer); lua_pushnumber(L, left); return 2; @@ -609,13 +659,16 @@ static int mime_global_eol(lua_State *L) const char *marker = luaL_optstring(L, 3, CRLF); luaL_Buffer buffer; luaL_buffinit(L, &buffer); - while (input < last) - ctx = eolprocess(*input++, ctx, marker, &buffer); /* if the last character was a candidate, we output a new line */ if (!input) { - if (eolcandidate(ctx)) luaL_addstring(&buffer, marker); - ctx = 0; + if (eolcandidate(ctx)) lua_pushstring(L, marker); + else lua_pushnil(L); + lua_pushnumber(L, 0); + return 2; } + /* process all input */ + while (input < last) + ctx = eolprocess(*input++, ctx, marker, &buffer); luaL_pushresult(&buffer); lua_pushnumber(L, ctx); return 2; diff --git a/src/mime.lua b/src/mime.lua index 4df0388..6db832d 100644 --- a/src/mime.lua +++ b/src/mime.lua @@ -1,4 +1,4 @@ -if not ltn12 then error('This module requires LTN12') end +if not ltn12 then error('Requires LTN12 module') end -- create mime namespace mime = mime or {} -- make all module globals fall into mime namespace diff --git a/src/smtp.lua b/src/smtp.lua index 8b65e44..6b02d14 100644 --- a/src/smtp.lua +++ b/src/smtp.lua @@ -19,7 +19,7 @@ DOMAIN = os.getenv("SERVER_NAME") or "localhost" SERVER = "localhost" function stuff() - return socket.cicle(dot, 2) + return ltn12.filter.cycle(dot, 2) end -- tries to get a pattern from the server and closes socket on error diff --git a/src/tp.lua b/src/tp.lua new file mode 100644 index 0000000..d8dabc0 --- /dev/null +++ b/src/tp.lua @@ -0,0 +1,111 @@ +----------------------------------------------------------------------------- +-- Unified SMTP/FTP subsystem +-- LuaSocket toolkit. +-- Author: Diego Nehab +-- Conforming to: RFC 2616, LTN7 +-- RCS ID: $Id$ +----------------------------------------------------------------------------- +-- make sure LuaSocket is loaded +if not LUASOCKET_LIBNAME then error('module requires LuaSocket') end +-- get LuaSocket namespace +local socket = _G[LUASOCKET_LIBNAME] +if not socket then error('module requires LuaSocket') end +-- create namespace inside LuaSocket namespace +socket.tp = socket.tp or {} +-- make all module globals fall into namespace +setmetatable(socket.tp, { __index = _G }) +setfenv(1, socket.tp) + +TIMEOUT = 60 + +-- tries to get a pattern from the server and closes socket on error +local function try_receiving(sock, pattern) + local data, message = sock:receive(pattern) + if not data then sock:close() end + return data, message +end + +-- tries to send data to server and closes socket on error +local function try_sending(sock, data) + local sent, message = sock:send(data) + if not sent then sock:close() end + return sent, message +end + +-- gets server reply +local function get_reply(sock) + local code, current, separator, _ + local line, message = try_receiving(sock) + local reply = line + if message then return nil, message end + _, _, code, separator = string.find(line, "^(%d%d%d)(.?)") + if not code then return nil, "invalid server reply" end + if separator == "-" then -- reply is multiline + repeat + line, message = try_receiving(sock) + if message then return nil, message end + _,_, current, separator = string.find(line, "^(%d%d%d)(.)") + if not current or not separator then + return nil, "invalid server reply" + end + reply = reply .. "\n" .. line + -- reply ends with same code + until code == current and separator == " " + end + return code, reply +end + +-- metatable for sock object +local metatable = { __index = {} } + +-- execute the "check" instr +function metatable.__index:check(ok) + local code, reply = get_reply(self.sock) + if not code then return nil, reply end + if type(ok) ~= "function" then + if type(ok) ~= "table" then ok = {ok} end + for i, v in ipairs(ok) do + if string.find(code, v) then return code, reply end + end + return nil, reply + else return ok(code, reply) end +end + +function metatable.__index:cmdchk(cmd, arg, ok) + local code, err = self:command(cmd, arg) + if not code then return nil, err end + return self:check(ok) +end + +-- execute the "command" instr +function metatable.__index:command(cmd, arg) + if arg then return try_sending(self.sock, cmd .. " " .. arg.. "\r\n") + return try_sending(self.sock, cmd .. "\r\n") end +end + +function metatable.__index:sink(snk, pat) + local chunk, err = sock:receive(pat) + return snk(chunk, err) +end + +function metatable.__index:source(src, instr) + while true do + local chunk, err = src() + if not chunk then return not err, err end + local ret, err = try_sending(self.sock, chunk) + if not ret then return nil, err end + end +end + +-- closes the underlying sock +function metatable.__index:close() + self.sock:close() +end + +-- connect with server and return sock object +function connect(host, port) + local sock, err = socket.connect(host, port) + if not sock then return nil, message end + sock:settimeout(TIMEOUT) + return setmetatable({sock = sock}, metatable) +end diff --git a/src/wsocket.c b/src/wsocket.c index 2993c35..af3f8d8 100644 --- a/src/wsocket.c +++ b/src/wsocket.c @@ -269,7 +269,7 @@ int sock_recv(p_sock ps, char *data, size_t count, size_t *got, int timeout) fd_set fds; int ret; *got = 0; - if (taken == 0) return IO_CLOSED; + if (taken == 0 || WSAGetLastError() != WSAEWOULDBLOCK) return IO_CLOSED; FD_ZERO(&fds); FD_SET(sock, &fds); ret = sock_select(0, &fds, NULL, NULL, timeout); @@ -295,7 +295,7 @@ int sock_recvfrom(p_sock ps, char *data, size_t count, size_t *got, fd_set fds; int ret; *got = 0; - if (taken == 0) return IO_CLOSED; + if (taken == 0 || WSAGetLastError() != WSAEWOULDBLOCK) return IO_CLOSED; FD_ZERO(&fds); FD_SET(sock, &fds); ret = sock_select(0, &fds, NULL, NULL, timeout); diff --git a/test/httptest.lua b/test/httptest.lua index c9a74a8..04c0ed0 100644 --- a/test/httptest.lua +++ b/test/httptest.lua @@ -12,7 +12,7 @@ socket.http.TIMEOUT = 5 local t = socket.time() -host = host or "diego-interface2.student.dyn.CS.Princeton.EDU" +host = host or "diego.student.princeton.edu" proxy = proxy or "http://localhost:3128" prefix = prefix or "/luasocket-test" cgiprefix = cgiprefix or "/luasocket-test-cgi" @@ -71,8 +71,8 @@ local check_request = function(request, expect, ignore) check_result(response, expect, ignore) end -local check_request_cb = function(request, response, expect, ignore) - local response = socket.http.request_cb(request, response) +local check_request_cb = function(request, expect, ignore) + local response = socket.http.request_cb(request) check_result(response, expect, ignore) end @@ -83,7 +83,7 @@ local back, h, c, e = socket.http.get("http://" .. host .. forth) if not back then fail(e) end back = socket.url.parse(back) if similar(back.query, "this+is+the+query+string") then print("ok") -else fail() end +else fail(back.query) end ------------------------------------------------------------------------ io.write("testing query string correctness: ") @@ -168,31 +168,28 @@ back = socket.http.post("http://" .. host .. cgiprefix .. "/cat", index) check(back == index) ------------------------------------------------------------------------ -io.write("testing send.file and receive.file callbacks: ") +io.write("testing ltn12.(sink|source).file: ") request = { url = "http://" .. host .. cgiprefix .. "/cat", method = "POST", - body_cb = socket.callback.send.file(io.open(index_file, "r")), + source = ltn12.source.file(io.open(index_file, "r")), + sink = ltn12.sink.file(io.open(index_file .. "-back", "w")), headers = { ["content-length"] = string.len(index) } } -response = { - body_cb = socket.callback.receive.file(io.open(index_file .. "-back", "w")) -} expect = { code = 200 } ignore = { - body_cb = 1, status = 1, headers = 1 } -check_request_cb(request, response, expect, ignore) +check_request_cb(request, expect, ignore) back = readfile(index_file .. "-back") check(back == index) os.remove(index_file .. "-back") ------------------------------------------------------------------------ -io.write("testing send.chain and receive.chain callbacks: ") +io.write("testing ltn12.(sink|source).chain and mime.(encode|decode): ") local function b64length(len) local a = math.ceil(len/3)*4 @@ -200,26 +197,26 @@ local function b64length(len) return a + l*2 end -local req_cb = socket.callback.send.chain( - socket.callback.send.file(io.open(index_file, "r")), - socket.mime.chain( - socket.mime.encode("base64"), - socket.mime.wrap("base64") +local source = ltn12.source.chain( + ltn12.source.file(io.open(index_file, "r")), + ltn12.filter.chain( + mime.encode("base64"), + mime.wrap("base64") ) ) -local resp_cb = socket.callback.receive.chain( - socket.mime.decode("base64"), - socket.callback.receive.file(io.open(index_file .. "-back", "w")) +local sink = ltn12.sink.chain( + mime.decode("base64"), + ltn12.sink.file(io.open(index_file .. "-back", "w")) ) request = { url = "http://" .. host .. cgiprefix .. "/cat", method = "POST", - body_cb = req_cb, + source = source, + sink = sink, headers = { ["content-length"] = b64length(string.len(index)) } } -response = { body_cb = resp_cb } expect = { code = 200 } @@ -228,7 +225,7 @@ ignore = { status = 1, headers = 1 } -check_request_cb(request, response, expect, ignore) +check_request_cb(request, expect, ignore) back = readfile(index_file .. "-back") check(back == index) os.remove(index_file .. "-back") @@ -362,7 +359,7 @@ io.write("testing manual basic auth: ") request = { url = "http://" .. host .. prefix .. "/auth/index.html", headers = { - authorization = "Basic " .. (socket.mime.b64("luasocket:password")) + authorization = "Basic " .. (mime.b64("luasocket:password")) } } expect = { diff --git a/test/mimetest.lua b/test/mimetest.lua index 1a7e427..4a0a20a 100644 --- a/test/mimetest.lua +++ b/test/mimetest.lua @@ -31,18 +31,27 @@ local mao = [[ assim, nem tudo o que dava exprimia grande confiança. ]] +local function random(handle, io_err) + if handle then + return function() + local chunk = handle:read(math.random(0, 1024)) + if not chunk then handle:close() end + return chunk + end + else source.empty(io_err or "unable to open file") end +end + +local what = nil local function transform(input, output, filter) - local fi, err = io.open(input, "rb") - if not fi then fail(err) end - local fo, err = io.open(output, "wb") - if not fo then fail(err) end - while 1 do - local chunk = fi:read(math.random(0, 1024)) - fo:write(filter(chunk)) - if not chunk then break end - end - fi:close() - fo:close() + local source = random(io.open(input, "rb")) + local sink = ltn12.sink.file(io.open(output, "wb")) + if what then + sink = ltn12.sink.chain(filter, sink) + else + source = ltn12.source.chain(source, filter) + end + --what = not what + ltn12.pump(source, sink) end local function encode_qptest(mode) diff --git a/test/stufftest.lua b/test/stufftest.lua new file mode 100644 index 0000000..5eb8005 --- /dev/null +++ b/test/stufftest.lua @@ -0,0 +1,19 @@ +function test_dot(original, right) + local result, n = socket.smtp.dot(2, original) + assert(result == right, "->" .. result .. "<-") + print("ok") +end + +function test_stuff(original, right) + local result, n = socket.smtp.dot(2, original) + assert(result == right, "->" .. result .. "<-") + print("ok") +end + +test_dot("abc", "abc") +test_dot("", "") +test_dot("\r\n", "\r\n") +test_dot("\r\n.", "\r\n..") +test_dot(".\r\n.", "..\r\n..") +test_dot(".\r\n.", "..\r\n..") +test_dot("abcd.\r\n.", "abcd.\r\n..")