diff --git a/TODO b/TODO index 110a78c..4fe107b 100644 --- a/TODO +++ b/TODO @@ -19,6 +19,13 @@ * Separar as classes em arquivos * Retorno de sendto em datagram sockets pode ser refused +falar sobre encodet/wrapt/decodet no manual sobre mime + + +RECEIVE MUDOU!!! COLOCAR NO MANUAL. +HTTP.lua mudou bastante também. + + change mime.eol to output marker on detection of first candidate, instead of on the second. that way it works in one pass for strings that end with one candidate. diff --git a/etc/eol.lua b/etc/eol.lua index 6b2a8a9..aa43596 100644 --- a/etc/eol.lua +++ b/etc/eol.lua @@ -1,9 +1,6 @@ -marker = {['-u'] = '\10', ['-d'] = '\13\10'} -arg = arg or {'-u'} -marker = marker[arg[1]] or marker['-u'] -local convert = socket.mime.normalize(marker) -while 1 do - local chunk = io.read(1) - io.write(convert(chunk)) - if not chunk then break end -end +local marker = '\n' +if arg and arg[1] == '-d' then marker = '\r\n' end +local filter = mime.normalize(marker) +local source = ltn12.source.chain(ltn12.source.file(io.stdin), filter) +local sink = ltn12.sink.file(io.stdout) +ltn12.pump(source, sink) diff --git a/etc/get.lua b/etc/get.lua index 0306b54..eafebda 100644 --- a/etc/get.lua +++ b/etc/get.lua @@ -93,7 +93,7 @@ function getbyhttp(url, file) local save = ltn12.sink.file(file or io.stdout) -- only print feedback if output is not stdout if file then save = ltn12.sink.chain(stats(gethttpsize(url)), save) end - local respt = socket.http.request_cb({url = url, sink = save}) + local respt = socket.http.request {url = url, sink = save } if respt.code ~= 200 then print(respt.status or respt.error) end end @@ -103,7 +103,7 @@ function getbyftp(url, file) -- only print feedback if output is not stdout -- and we don't know how big the file is if file then save = ltn12.sink.chain(stats(), save) end - local ret, err = socket.ftp.get_cb {url = url, sink = save, type = "i"} + local ret, err = socket.ftp.get {url = url, sink = save, type = "i"} if err then print(err) end end diff --git a/etc/qp.lua b/etc/qp.lua index 1ca0ae2..08545db 100644 --- a/etc/qp.lua +++ b/etc/qp.lua @@ -2,17 +2,15 @@ local convert arg = arg or {} local mode = arg and arg[1] or "-et" if mode == "-et" then - local normalize = socket.mime.normalize() - local qp = socket.mime.encode("quoted-printable") - local wrap = socket.mime.wrap("quoted-printable") - convert = socket.mime.chain(normalize, qp, wrap) + local normalize = mime.normalize() + local qp = mime.encode("quoted-printable") + local wrap = mime.wrap("quoted-printable") + convert = ltn12.filter.chain(normalize, qp, wrap) elseif mode == "-eb" then - local qp = socket.mime.encode("quoted-printable", "binary") - local wrap = socket.mime.wrap("quoted-printable") - convert = socket.mime.chain(qp, wrap) -else convert = socket.mime.decode("quoted-printable") end -while 1 do - local chunk = io.read(4096) - io.write(convert(chunk)) - if not chunk then break end -end + local qp = mime.encode("quoted-printable", "binary") + local wrap = mime.wrap("quoted-printable") + convert = ltn12.filter.chain(qp, wrap) +else convert = mime.decode("quoted-printable") end +local source = ltn12.source.chain(ltn12.source.file(io.stdin), convert) +local sink = ltn12.sink.file(io.stdout) +ltn12.pump(source, sink) diff --git a/samples/daytimeclnt.lua b/samples/daytimeclnt.lua index 29abe17..63f4017 100644 --- a/samples/daytimeclnt.lua +++ b/samples/daytimeclnt.lua @@ -14,8 +14,9 @@ host = socket.dns.toip(host) udp = socket.udp() print("Using host '" ..host.. "' and port " ..port.. "...") udp:setpeername(host, port) +udp:settimeout(3) sent, err = udp:send("anything") -if err then print(err) exit() end +if err then print(err) os.exit() end dgram, err = udp:receive() -if not dgram then print(err) exit() end +if not dgram then print(err) os.exit() end io.write(dgram) diff --git a/src/http.lua b/src/http.lua index a10cf50..ab166e3 100644 --- a/src/http.lua +++ b/src/http.lua @@ -2,7 +2,7 @@ -- HTTP/1.1 client support for the Lua language. -- LuaSocket toolkit. -- Author: Diego Nehab --- Conforming to: RFC 2616, LTN7 +-- Conforming to RFC 2616 -- RCS ID: $Id$ ----------------------------------------------------------------------------- -- make sure LuaSocket is loaded @@ -39,21 +39,18 @@ local function third(a, b, c) return c end -local function shift(a, b, c, d) - return c, d -end - --- resquest_p forward declaration -local request_p - -local function receive_headers(sock, headers) - local line, name, value +local function receive_headers(reqt, respt) + local headers = {} + local sock = respt.tmp.sock + local line, name, value, _ + -- store results + respt.headers = headers -- get first line line = socket.try(sock:receive()) -- headers go until a blank line is found while line ~= "" do -- get field-name and value - name, value = shift(string.find(line, "^(.-):%s*(.*)")) + _, _, name, value = string.find(line, "^(.-):%s*(.*)") assert(name and value, "malformed reponse headers") name = string.lower(name) -- get next line (value might be folded) @@ -100,7 +97,10 @@ local function receive_body_bychunks(sock, sink) -- let callback know we are done hand(sink, nil) -- servers shouldn't send trailer headers, but who trusts them? - receive_headers(sock, {}) + local line = socket.try(sock:receive()) + while line ~= "" do + line = socket.try(sock:receive()) + end end local function receive_body_bylength(sock, length, sink) @@ -245,7 +245,7 @@ local function open(reqt, respt) socket.try(sock:connect(host, port)) end -function adjust_headers(reqt, respt) +local function adjust_headers(reqt, respt) local lower = {} local headers = reqt.headers or {} -- set default headers @@ -261,7 +261,7 @@ function adjust_headers(reqt, respt) respt.tmp.headers = lower end -function parse_url(reqt, respt) +local function parse_url(reqt, respt) -- parse url with default fields local parsed = socket.url.parse(reqt.url, { host = "", @@ -280,11 +280,16 @@ function parse_url(reqt, respt) respt.tmp.parsed = parsed end +-- forward declaration +local request_p + local function should_authorize(reqt, respt) -- if there has been an authorization attempt, it must have failed if reqt.headers and reqt.headers["authorization"] then return nil end - -- if we don't have authorization information, we can't retry - return respt.tmp.parsed.user and respt.tmp.parsed.password + -- if last attempt didn't fail due to lack of authentication, + -- or we don't have authorization information, we can't retry + return respt.code == 401 and + respt.tmp.parsed.user and respt.tmp.parsed.password end local function clone(headers) @@ -338,14 +343,14 @@ local function redirect(reqt, respt) if respt.headers then respt.headers.location = redirt.url end end +-- execute a request of through an exception function request_p(reqt, respt) parse_url(reqt, respt) adjust_headers(reqt, respt) open(reqt, respt) send_request(reqt, respt) receive_status(reqt, respt) - respt.headers = {} - receive_headers(respt.tmp.sock, respt.headers) + receive_headers(reqt, respt) if should_redirect(reqt, respt) then respt.tmp.sock:close() redirect(reqt, respt) diff --git a/src/ltn12.lua b/src/ltn12.lua index dc49d80..ed3449b 100644 --- a/src/ltn12.lua +++ b/src/ltn12.lua @@ -22,6 +22,7 @@ end -- returns a high level filter that cycles a cycles a low-level filter function filter.cycle(low, ctx, extra) + if type(low) ~= 'function' then error('invalid low-level filter', 2) end return function(chunk) local ret ret, ctx = low(ctx, chunk, extra) @@ -31,6 +32,8 @@ end -- chains two filters together local function chain2(f1, f2) + if type(f1) ~= 'function' then error('invalid filter', 2) end + if type(f2) ~= 'function' then error('invalid filter', 2) end return function(chunk) return f2(f1(chunk)) end @@ -40,6 +43,7 @@ end function filter.chain(...) local f = arg[1] for i = 2, table.getn(arg) do + if type(arg[i]) ~= 'function' then error('invalid filter', 2) end f = chain2(f, arg[i]) end return f @@ -74,6 +78,7 @@ end -- turns a fancy source into a simple source function source.simplify(src) + if type(src) ~= 'function' then error('invalid source', 2) end return function() local chunk, err_or_new = src() src = err_or_new or src @@ -97,6 +102,7 @@ end -- creates rewindable source function source.rewind(src) + if type(src) ~= 'function' then error('invalid source', 2) end local t = {} return function(chunk) if not chunk then @@ -111,6 +117,8 @@ end -- chains a source with a filter function source.chain(src, f) + if type(src) ~= 'function' then error('invalid source', 2) end + if type(f) ~= 'function' then error('invalid filter', 2) end local co = coroutine.create(function() while true do local chunk, err = src() @@ -157,6 +165,7 @@ end -- turns a fancy sink into a simple sink function sink.simplify(snk) + if type(snk) ~= 'function' then error('invalid sink', 2) end return function(chunk, err) local ret, err_or_new = snk(chunk, err) if not ret then return nil, err_or_new end @@ -195,6 +204,8 @@ end -- chains a sink with a filter function sink.chain(f, snk) + if type(snk) ~= 'function' then error('invalid sink', 2) end + if type(f) ~= 'function' then error('invalid filter', 2) end return function(chunk, err) local filtered = f(chunk) local done = chunk and "" @@ -209,6 +220,8 @@ end -- pumps all data from a source to a sink function pump(src, snk) + if type(src) ~= 'function' then error('invalid source', 2) end + if type(snk) ~= 'function' then error('invalid sink', 2) end while true do local chunk, src_err = src() local ret, snk_err = snk(chunk, src_err) diff --git a/src/smtp.lua b/src/smtp.lua index c823c97..ed8bd15 100644 --- a/src/smtp.lua +++ b/src/smtp.lua @@ -20,16 +20,17 @@ DOMAIN = os.getenv("SERVER_NAME") or "localhost" -- default time zone (means we don't know) ZONE = "-0000" -function stuff() - return ltn12.filter.cycle(dot, 2) -end - local function shift(a, b, c) return b, c end +-- high level stuffing filter +function stuff() + return ltn12.filter.cycle(dot, 2) +end + -- send message or throw an exception -function psend(control, mailt) +local function send_p(control, mailt) socket.try(control:check("2..")) socket.try(control:command("EHLO", mailt.domain or DOMAIN)) socket.try(control:check("2..")) @@ -61,11 +62,11 @@ local function newboundary() math.random(0, 99999), seqno) end --- sendmessage forward declaration -local sendmessage +-- send_message forward declaration +local send_message -- yield multipart message body from a multipart message table -local function sendmultipart(mesgt) +local function send_multipart(mesgt) local bd = newboundary() -- define boundary and finish headers coroutine.yield('content-type: multipart/mixed; boundary="' .. @@ -75,7 +76,7 @@ local function sendmultipart(mesgt) -- send each part separated by a boundary for i, m in ipairs(mesgt.body) do coroutine.yield("\r\n--" .. bd .. "\r\n") - sendmessage(m) + send_message(m) end -- send last boundary coroutine.yield("\r\n--" .. bd .. "--\r\n\r\n") @@ -84,7 +85,7 @@ local function sendmultipart(mesgt) end -- yield message body from a source -local function sendsource(mesgt) +local function send_source(mesgt) -- set content-type if user didn't override if not mesgt.headers or not mesgt.headers["content-type"] then coroutine.yield('content-type: text/plain; charset="iso-8859-1"\r\n') @@ -101,7 +102,7 @@ local function sendsource(mesgt) end -- yield message body from a string -local function sendstring(mesgt) +local function send_string(mesgt) -- set content-type if user didn't override if not mesgt.headers or not mesgt.headers["content-type"] then coroutine.yield('content-type: text/plain; charset="iso-8859-1"\r\n') @@ -114,7 +115,7 @@ local function sendstring(mesgt) end -- yield the headers one by one -local function sendheaders(mesgt) +local function send_headers(mesgt) if mesgt.headers then for i,v in pairs(mesgt.headers) do coroutine.yield(i .. ':' .. v .. "\r\n") @@ -123,15 +124,15 @@ local function sendheaders(mesgt) end -- message source -function sendmessage(mesgt) - sendheaders(mesgt) - if type(mesgt.body) == "table" then sendmultipart(mesgt) - elseif type(mesgt.body) == "function" then sendsource(mesgt) - else sendstring(mesgt) end +function send_message(mesgt) + send_headers(mesgt) + if type(mesgt.body) == "table" then send_multipart(mesgt) + elseif type(mesgt.body) == "function" then send_source(mesgt) + else send_string(mesgt) end end -- set defaul headers -local function adjustheaders(mesgt) +local function adjust_headers(mesgt) mesgt.headers = mesgt.headers or {} mesgt.headers["mime-version"] = "1.0" mesgt.headers["date"] = mesgt.headers["date"] or @@ -140,16 +141,16 @@ local function adjustheaders(mesgt) end function message(mesgt) - adjustheaders(mesgt) + adjust_headers(mesgt) -- create and return message source - local co = coroutine.create(function() sendmessage(mesgt) end) + local co = coroutine.create(function() send_message(mesgt) end) return function() return shift(coroutine.resume(co)) end end function send(mailt) local c, e = socket.tp.connect(mailt.server or SERVER, mailt.port or PORT) if not c then return nil, e end - local s, e = pcall(psend, c, mailt) + local s, e = pcall(send_p, c, mailt) c:close() if s then return true else return nil, e end diff --git a/test/httptest.lua b/test/httptest.lua index ddeea50..4ec0cc1 100644 --- a/test/httptest.lua +++ b/test/httptest.lua @@ -3,7 +3,7 @@ -- needs ScriptAlias from /home/c/diego/tec/luasocket/test/cgi -- to "/luasocket-test-cgi" and "/luasocket-test-cgi/" -- needs "AllowOverride AuthConfig" on /home/c/diego/tec/luasocket/test/auth -dofile("noglobals.lua") +dofile("testsupport.lua") local host, proxy, request, response, index_file local ignore, expect, index, prefix, cgiprefix, index_crlf @@ -18,33 +18,9 @@ prefix = prefix or "/luasocket-test" cgiprefix = cgiprefix or "/luasocket-test-cgi" index_file = "test/index.html" -local readfile = function(name) - local f = io.open(name, "r") - if not f then return nil end - local s = f:read("*a") - f:close() - return s -end - -- read index with CRLF convention index = readfile(index_file) -local similar = function(s1, s2) - return string.lower(string.gsub(s1 or "", "%s", "")) == - string.lower(string.gsub(s2 or "", "%s", "")) -end - -local fail = function(s) - s = s or "failed!" - print(s) - os.exit() -end - -local check = function (v, e) - if v then print("ok") - else fail(e) end -end - local check_result = function(response, expect, ignore) for i,v in response do if not ignore[i] then @@ -171,7 +147,7 @@ check_request(request, expect, ignore) ------------------------------------------------------------------------ io.write("testing simple post function: ") back = socket.http.post("http://" .. host .. cgiprefix .. "/cat", index) -check(back == index) +assert(back == index) ------------------------------------------------------------------------ io.write("testing ltn12.(sink|source).file: ") @@ -191,7 +167,7 @@ ignore = { } check_request(request, expect, ignore) back = readfile(index_file .. "-back") -check(back == index) +assert(back == index) os.remove(index_file .. "-back") ------------------------------------------------------------------------ @@ -233,7 +209,7 @@ ignore = { } check_request(request, expect, ignore) back = readfile(index_file .. "-back") -check(back == index) +assert(back == index) os.remove(index_file .. "-back") ------------------------------------------------------------------------ @@ -434,7 +410,7 @@ check_request(request, expect, ignore) local body io.write("testing simple get function: ") body = socket.http.get("http://" .. host .. prefix .. "/index.html") -check(body == index) +assert(body == index) ------------------------------------------------------------------------ io.write("testing HEAD method: ") @@ -443,7 +419,7 @@ response = socket.http.request { method = "HEAD", url = "http://www.cs.princeton.edu/~diego/" } -check(response and response.headers) +assert(response and response.headers) ------------------------------------------------------------------------ print("passed all tests") diff --git a/test/smtptest.lua b/test/smtptest.lua index 8468408..e812737 100644 --- a/test/smtptest.lua +++ b/test/smtptest.lua @@ -16,7 +16,7 @@ local err dofile("mbox.lua") local parse = mbox.parse -dofile("noglobals.lua") +dofile("testsupport.lua") local total = function() local t = 0 diff --git a/test/testsupport.lua b/test/testsupport.lua new file mode 100644 index 0000000..ca3cd95 --- /dev/null +++ b/test/testsupport.lua @@ -0,0 +1,37 @@ +function readfile(name) + local f = io.open(name, "r") + if not f then return nil end + local s = f:read("*a") + f:close() + return s +end + +function similar(s1, s2) + return string.lower(string.gsub(s1 or "", "%s", "")) == + string.lower(string.gsub(s2 or "", "%s", "")) +end + +function fail(msg) + msg = msg or "failed" + error(msg, 2) +end + +function compare(input, output) + local original = readfile(input) + local recovered = readfile(output) + if original ~= recovered then fail("comparison failed") + else print("ok") end +end + +local G = _G +local set = rawset +local warn = print + +local setglobal = function(table, key, value) + warn("changed " .. key) + set(table, key, value) +end + +setmetatable(G, { + __newindex = setglobal +}) diff --git a/test/urltest.lua b/test/urltest.lua index 990a3e5..7e0e73f 100644 --- a/test/urltest.lua +++ b/test/urltest.lua @@ -1,4 +1,4 @@ -dofile("noglobals.lua") +dofile("testsupport.lua") local check_build_url = function(parsed) local built = socket.url.build(parsed)