From e5a090b01cd5b490f7331235b7c58c36c05bb7b6 Mon Sep 17 00:00:00 2001 From: Diego Nehab Date: Fri, 26 Mar 2004 06:05:20 +0000 Subject: [PATCH] Using socket pumps in http.lua. Adjusted socket.try. --- TODO | 1 + etc/b64.lua | 2 +- src/http.lua | 138 ++++++++++------------------------------------ src/mime.lua | 22 +++----- test/httptest.lua | 2 + 5 files changed, 41 insertions(+), 124 deletions(-) diff --git a/TODO b/TODO index 79a15af..1e30a78 100644 --- a/TODO +++ b/TODO @@ -27,6 +27,7 @@ falar sobre encodet/wrapt/decodet no manual sobre mime RECEIVE MUDOU!!! COLOCAR NO MANUAL. HTTP.lua mudou bastante também. +fazer com que a socket.source e socket.sink sejam "selectable". change mime.eol to output marker on detection of first candidate, instead of on the second. that way it works in one pass for strings that end with diff --git a/etc/b64.lua b/etc/b64.lua index ea157c4..3f861b4 100644 --- a/etc/b64.lua +++ b/etc/b64.lua @@ -9,4 +9,4 @@ else convert = ltn12.filter.chain(base64, wrap) end source = ltn12.source.chain(source, convert) -ltn12.pump(source, sink) +repeat until not ltn12.pump(source, sink) diff --git a/src/http.lua b/src/http.lua index 8b06184..da18aaf 100644 --- a/src/http.lua +++ b/src/http.lua @@ -51,7 +51,7 @@ local function receive_headers(reqt, respt, tmp) while line ~= "" do -- get field-name and value _, _, name, value = string.find(line, "^(.-):%s*(.*)") - assert(name and value, "malformed reponse headers") + socket.try(name and value, "malformed reponse headers") name = string.lower(name) -- get next line (value might be folded) line = socket.try(sock:receive()) @@ -66,119 +66,32 @@ local function receive_headers(reqt, respt, tmp) end end -local function abort(cb, err) - local go, cb_err = cb(nil, err) - error(cb_err or err) -end - -local function hand(cb, chunk) - local go, cb_err = cb(chunk) - assert(go, cb_err or "aborted by callback") -end - -local function receive_body_bychunks(sock, sink) - while 1 do - -- get chunk size, skip extention - local line, err = sock:receive() - if err then abort(sink, err) end - local size = tonumber(string.gsub(line, ";.*", ""), 16) - if not size then abort(sink, "invalid chunk size") end - -- was it the last chunk? - if size <= 0 then break end - -- get chunk - local chunk, err = sock:receive(size) - if err then abort(sink, err) end - -- pass chunk to callback - hand(sink, chunk) - -- skip CRLF on end of chunk - err = second(sock:receive()) - if err then abort(sink, err) end - end - -- let callback know we are done - hand(sink, nil) - -- servers shouldn't send trailer headers, but who trusts them? - local line = socket.try(sock:receive()) - while line ~= "" do - line = socket.try(sock:receive()) - end -end - -local function receive_body_bylength(sock, length, sink) - while length > 0 do - local size = math.min(BLOCKSIZE, length) - local chunk, err = sock:receive(size) - if err then abort(sink, err) end - length = length - string.len(chunk) - -- see if there was an error - hand(sink, chunk) - end - -- let callback know we are done - hand(sink, nil) -end - -local function receive_body_untilclosed(sock, sink) - while true do - local chunk, err, partial = sock:receive(BLOCKSIZE) - -- see if we are done - if err == "closed" then - hand(sink, partial) - break - end - hand(sink, chunk) - -- see if there was an error - if err then abort(sink, err) end - end - -- let callback know we are done - hand(sink, nil) -end - local function receive_body(reqt, respt, tmp) local sink = reqt.sink or ltn12.sink.null() - local headers = respt.headers - local sock = tmp.sock - local te = headers["transfer-encoding"] + local pump = reqt.pump or ltn12.pump + local source + local te = respt.headers["transfer-encoding"] if te and te ~= "identity" then -- get by chunked transfer-coding of message body - receive_body_bychunks(sock, sink) - elseif tonumber(headers["content-length"]) then + source = socket.source("http-chunked", tmp.sock) + elseif tonumber(respt.headers["content-length"]) then -- get by content-length - local length = tonumber(headers["content-length"]) - receive_body_bylength(sock, length, sink) + local length = tonumber(respt.headers["content-length"]) + source = socket.source("by-length", tmp.sock, length) else -- get it all until connection closes - receive_body_untilclosed(sock, sink) - end -end - -local function send_body_bychunks(data, source) - while true do - local chunk, err = source() - assert(chunk or not err, err) - if not chunk then break end - socket.try(data:send(string.format("%X\r\n", string.len(chunk)))) - socket.try(data:send(chunk, "\r\n")) - end - socket.try(data:send("0\r\n\r\n")) -end - -local function send_body(data, source) - while true do - local chunk, err = source() - assert(chunk or not err, err) - if not chunk then break end - socket.try(data:send(chunk)) + source = socket.source("until-closed", tmp.sock) end + socket.try(pump(source, sink)) end local function send_headers(sock, headers) -- send request headers for i, v in pairs(headers) do socket.try(sock:send(i .. ": " .. v .. "\r\n")) ---io.write(i .. ": " .. v .. "\r\n") end -- mark end of request headers socket.try(sock:send("\r\n")) ---io.write("\r\n") end local function should_receive_body(reqt, respt, tmp) @@ -211,22 +124,21 @@ end local function send_request(reqt, respt, tmp) local uri = request_uri(reqt, respt, tmp) - local sock = tmp.sock local headers = tmp.headers + local pump = reqt.pump or ltn12.pump -- send request line - socket.try(sock:send((reqt.method or "GET") + socket.try(tmp.sock:send((reqt.method or "GET") .. " " .. uri .. " HTTP/1.1\r\n")) ---io.write((reqt.method or "GET") - --.. " " .. uri .. " HTTP/1.1\r\n") - -- send request headers headeres if reqt.source and not headers["content-length"] then headers["transfer-encoding"] = "chunked" end - send_headers(sock, headers) + send_headers(tmp.sock, headers) -- send request message body, if any - if reqt.source then - if headers["content-length"] then send_body(sock, reqt.source) - else send_body_bychunks(sock, reqt.source) end + if not reqt.source then return end + if headers["content-length"] then + socket.try(pump(reqt.source, socket.sink(tmp.sock))) + else + socket.try(pump(reqt.source, socket.sink("http-chunked", tmp.sock))) end end @@ -235,7 +147,7 @@ local function open(reqt, respt, tmp) local host, port if proxy then local pproxy = socket.url.parse(proxy) - assert(pproxy.port and pproxy.host, "invalid proxy") + socket.try(pproxy.port and pproxy.host, "invalid proxy") host, port = pproxy.host, pproxy.port else host, port = tmp.parsed.host, tmp.parsed.port @@ -271,9 +183,8 @@ local function parse_url(reqt, respt, tmp) scheme = "http" }) -- scheme has to be http - if parsed.scheme ~= "http" then - error(string.format("unknown scheme '%s'", parsed.scheme)) - end + socket.try(parsed.scheme == "http", + string.format("unknown scheme '%s'", parsed.scheme)) -- explicit authentication info overrides that given by the URL parsed.user = reqt.user or parsed.user parsed.password = reqt.password or parsed.password @@ -342,6 +253,12 @@ local function redirect(reqt, respt, tmp) if respt.headers then respt.headers.location = redirt.url end end +local function skip_continue(reqt, respt, tmp) + if respt.code == 100 then + receive_status(reqt, respt, tmp) + end +end + -- execute a request of through an exception function request_p(reqt, respt, tmp) parse_url(reqt, respt, tmp) @@ -349,6 +266,7 @@ function request_p(reqt, respt, tmp) open(reqt, respt, tmp) send_request(reqt, respt, tmp) receive_status(reqt, respt, tmp) + skip_continue(reqt, respt, tmp) receive_headers(reqt, respt, tmp) if should_redirect(reqt, respt, tmp) then tmp.sock:close() diff --git a/src/mime.lua b/src/mime.lua index 8c2a5c0..d263d48 100644 --- a/src/mime.lua +++ b/src/mime.lua @@ -11,11 +11,14 @@ decodet = {} wrapt = {} -- creates a function that chooses a filter by name from a given table -local function choose(table) - return function(name, opt) +function choose(table) + return function(name, opt1, opt2) + if type(name) ~= "string" then + name, opt1, opt2 = "default", name, opt1 + end local f = table[name or "nil"] - if not f then error("unknown filter (" .. tostring(name) .. ")", 3) - else return f(opt) end + if not f then error("unknown key (" .. tostring(name) .. ")", 3) + else return f(opt1, opt2) end end end @@ -44,6 +47,7 @@ wrapt['text'] = function(length) return ltn12.filter.cycle(wrp, length, length) end wrapt['base64'] = wrapt['text'] +wrapt['default'] = wrapt['text'] wrapt['quoted-printable'] = function() return ltn12.filter.cycle(qpwrp, 76, 76) @@ -52,15 +56,7 @@ end -- function that choose the encoding, decoding or wrap algorithm encode = choose(encodet) decode = choose(decodet) --- it's different because there is a default wrap filter -local cwt = choose(wrapt) -function wrap(mode_or_length, length) - if type(mode_or_length) ~= "string" then - length = mode_or_length - mode_or_length = "text" - end - return cwt(mode_or_length, length) -end +wrap = choose(wrapt) -- define the end-of-line normalization filter function normalize(marker) diff --git a/test/httptest.lua b/test/httptest.lua index 4ec0cc1..c16ef2b 100644 --- a/test/httptest.lua +++ b/test/httptest.lua @@ -411,6 +411,7 @@ local body io.write("testing simple get function: ") body = socket.http.get("http://" .. host .. prefix .. "/index.html") assert(body == index) +print("ok") ------------------------------------------------------------------------ io.write("testing HEAD method: ") @@ -420,6 +421,7 @@ response = socket.http.request { url = "http://www.cs.princeton.edu/~diego/" } assert(response and response.headers) +print("ok") ------------------------------------------------------------------------ print("passed all tests")