Using socket pumps in http.lua.

Adjusted socket.try.
This commit is contained in:
Diego Nehab 2004-03-26 06:05:20 +00:00
parent e77f179200
commit e5a090b01c
5 changed files with 41 additions and 124 deletions

1
TODO
View File

@ -27,6 +27,7 @@ falar sobre encodet/wrapt/decodet no manual sobre mime
RECEIVE MUDOU!!! COLOCAR NO MANUAL. RECEIVE MUDOU!!! COLOCAR NO MANUAL.
HTTP.lua mudou bastante também. HTTP.lua mudou bastante também.
fazer com que a socket.source e socket.sink sejam "selectable".
change mime.eol to output marker on detection of first candidate, instead change mime.eol to output marker on detection of first candidate, instead
of on the second. that way it works in one pass for strings that end with of on the second. that way it works in one pass for strings that end with

View File

@ -9,4 +9,4 @@ else
convert = ltn12.filter.chain(base64, wrap) convert = ltn12.filter.chain(base64, wrap)
end end
source = ltn12.source.chain(source, convert) source = ltn12.source.chain(source, convert)
ltn12.pump(source, sink) repeat until not ltn12.pump(source, sink)

View File

@ -51,7 +51,7 @@ local function receive_headers(reqt, respt, tmp)
while line ~= "" do while line ~= "" do
-- get field-name and value -- get field-name and value
_, _, name, value = string.find(line, "^(.-):%s*(.*)") _, _, name, value = string.find(line, "^(.-):%s*(.*)")
assert(name and value, "malformed reponse headers") socket.try(name and value, "malformed reponse headers")
name = string.lower(name) name = string.lower(name)
-- get next line (value might be folded) -- get next line (value might be folded)
line = socket.try(sock:receive()) line = socket.try(sock:receive())
@ -66,119 +66,32 @@ local function receive_headers(reqt, respt, tmp)
end end
end end
local function abort(cb, err)
local go, cb_err = cb(nil, err)
error(cb_err or err)
end
local function hand(cb, chunk)
local go, cb_err = cb(chunk)
assert(go, cb_err or "aborted by callback")
end
local function receive_body_bychunks(sock, sink)
while 1 do
-- get chunk size, skip extention
local line, err = sock:receive()
if err then abort(sink, err) end
local size = tonumber(string.gsub(line, ";.*", ""), 16)
if not size then abort(sink, "invalid chunk size") end
-- was it the last chunk?
if size <= 0 then break end
-- get chunk
local chunk, err = sock:receive(size)
if err then abort(sink, err) end
-- pass chunk to callback
hand(sink, chunk)
-- skip CRLF on end of chunk
err = second(sock:receive())
if err then abort(sink, err) end
end
-- let callback know we are done
hand(sink, nil)
-- servers shouldn't send trailer headers, but who trusts them?
local line = socket.try(sock:receive())
while line ~= "" do
line = socket.try(sock:receive())
end
end
local function receive_body_bylength(sock, length, sink)
while length > 0 do
local size = math.min(BLOCKSIZE, length)
local chunk, err = sock:receive(size)
if err then abort(sink, err) end
length = length - string.len(chunk)
-- see if there was an error
hand(sink, chunk)
end
-- let callback know we are done
hand(sink, nil)
end
local function receive_body_untilclosed(sock, sink)
while true do
local chunk, err, partial = sock:receive(BLOCKSIZE)
-- see if we are done
if err == "closed" then
hand(sink, partial)
break
end
hand(sink, chunk)
-- see if there was an error
if err then abort(sink, err) end
end
-- let callback know we are done
hand(sink, nil)
end
local function receive_body(reqt, respt, tmp) local function receive_body(reqt, respt, tmp)
local sink = reqt.sink or ltn12.sink.null() local sink = reqt.sink or ltn12.sink.null()
local headers = respt.headers local pump = reqt.pump or ltn12.pump
local sock = tmp.sock local source
local te = headers["transfer-encoding"] local te = respt.headers["transfer-encoding"]
if te and te ~= "identity" then if te and te ~= "identity" then
-- get by chunked transfer-coding of message body -- get by chunked transfer-coding of message body
receive_body_bychunks(sock, sink) source = socket.source("http-chunked", tmp.sock)
elseif tonumber(headers["content-length"]) then elseif tonumber(respt.headers["content-length"]) then
-- get by content-length -- get by content-length
local length = tonumber(headers["content-length"]) local length = tonumber(respt.headers["content-length"])
receive_body_bylength(sock, length, sink) source = socket.source("by-length", tmp.sock, length)
else else
-- get it all until connection closes -- get it all until connection closes
receive_body_untilclosed(sock, sink) source = socket.source("until-closed", tmp.sock)
end
end
local function send_body_bychunks(data, source)
while true do
local chunk, err = source()
assert(chunk or not err, err)
if not chunk then break end
socket.try(data:send(string.format("%X\r\n", string.len(chunk))))
socket.try(data:send(chunk, "\r\n"))
end
socket.try(data:send("0\r\n\r\n"))
end
local function send_body(data, source)
while true do
local chunk, err = source()
assert(chunk or not err, err)
if not chunk then break end
socket.try(data:send(chunk))
end end
socket.try(pump(source, sink))
end end
local function send_headers(sock, headers) local function send_headers(sock, headers)
-- send request headers -- send request headers
for i, v in pairs(headers) do for i, v in pairs(headers) do
socket.try(sock:send(i .. ": " .. v .. "\r\n")) socket.try(sock:send(i .. ": " .. v .. "\r\n"))
--io.write(i .. ": " .. v .. "\r\n")
end end
-- mark end of request headers -- mark end of request headers
socket.try(sock:send("\r\n")) socket.try(sock:send("\r\n"))
--io.write("\r\n")
end end
local function should_receive_body(reqt, respt, tmp) local function should_receive_body(reqt, respt, tmp)
@ -211,22 +124,21 @@ end
local function send_request(reqt, respt, tmp) local function send_request(reqt, respt, tmp)
local uri = request_uri(reqt, respt, tmp) local uri = request_uri(reqt, respt, tmp)
local sock = tmp.sock
local headers = tmp.headers local headers = tmp.headers
local pump = reqt.pump or ltn12.pump
-- send request line -- send request line
socket.try(sock:send((reqt.method or "GET") socket.try(tmp.sock:send((reqt.method or "GET")
.. " " .. uri .. " HTTP/1.1\r\n")) .. " " .. uri .. " HTTP/1.1\r\n"))
--io.write((reqt.method or "GET")
--.. " " .. uri .. " HTTP/1.1\r\n")
-- send request headers headeres
if reqt.source and not headers["content-length"] then if reqt.source and not headers["content-length"] then
headers["transfer-encoding"] = "chunked" headers["transfer-encoding"] = "chunked"
end end
send_headers(sock, headers) send_headers(tmp.sock, headers)
-- send request message body, if any -- send request message body, if any
if reqt.source then if not reqt.source then return end
if headers["content-length"] then send_body(sock, reqt.source) if headers["content-length"] then
else send_body_bychunks(sock, reqt.source) end socket.try(pump(reqt.source, socket.sink(tmp.sock)))
else
socket.try(pump(reqt.source, socket.sink("http-chunked", tmp.sock)))
end end
end end
@ -235,7 +147,7 @@ local function open(reqt, respt, tmp)
local host, port local host, port
if proxy then if proxy then
local pproxy = socket.url.parse(proxy) local pproxy = socket.url.parse(proxy)
assert(pproxy.port and pproxy.host, "invalid proxy") socket.try(pproxy.port and pproxy.host, "invalid proxy")
host, port = pproxy.host, pproxy.port host, port = pproxy.host, pproxy.port
else else
host, port = tmp.parsed.host, tmp.parsed.port host, port = tmp.parsed.host, tmp.parsed.port
@ -271,9 +183,8 @@ local function parse_url(reqt, respt, tmp)
scheme = "http" scheme = "http"
}) })
-- scheme has to be http -- scheme has to be http
if parsed.scheme ~= "http" then socket.try(parsed.scheme == "http",
error(string.format("unknown scheme '%s'", parsed.scheme)) string.format("unknown scheme '%s'", parsed.scheme))
end
-- explicit authentication info overrides that given by the URL -- explicit authentication info overrides that given by the URL
parsed.user = reqt.user or parsed.user parsed.user = reqt.user or parsed.user
parsed.password = reqt.password or parsed.password parsed.password = reqt.password or parsed.password
@ -342,6 +253,12 @@ local function redirect(reqt, respt, tmp)
if respt.headers then respt.headers.location = redirt.url end if respt.headers then respt.headers.location = redirt.url end
end end
local function skip_continue(reqt, respt, tmp)
if respt.code == 100 then
receive_status(reqt, respt, tmp)
end
end
-- execute a request of through an exception -- execute a request of through an exception
function request_p(reqt, respt, tmp) function request_p(reqt, respt, tmp)
parse_url(reqt, respt, tmp) parse_url(reqt, respt, tmp)
@ -349,6 +266,7 @@ function request_p(reqt, respt, tmp)
open(reqt, respt, tmp) open(reqt, respt, tmp)
send_request(reqt, respt, tmp) send_request(reqt, respt, tmp)
receive_status(reqt, respt, tmp) receive_status(reqt, respt, tmp)
skip_continue(reqt, respt, tmp)
receive_headers(reqt, respt, tmp) receive_headers(reqt, respt, tmp)
if should_redirect(reqt, respt, tmp) then if should_redirect(reqt, respt, tmp) then
tmp.sock:close() tmp.sock:close()

View File

@ -11,11 +11,14 @@ decodet = {}
wrapt = {} wrapt = {}
-- creates a function that chooses a filter by name from a given table -- creates a function that chooses a filter by name from a given table
local function choose(table) function choose(table)
return function(name, opt) return function(name, opt1, opt2)
if type(name) ~= "string" then
name, opt1, opt2 = "default", name, opt1
end
local f = table[name or "nil"] local f = table[name or "nil"]
if not f then error("unknown filter (" .. tostring(name) .. ")", 3) if not f then error("unknown key (" .. tostring(name) .. ")", 3)
else return f(opt) end else return f(opt1, opt2) end
end end
end end
@ -44,6 +47,7 @@ wrapt['text'] = function(length)
return ltn12.filter.cycle(wrp, length, length) return ltn12.filter.cycle(wrp, length, length)
end end
wrapt['base64'] = wrapt['text'] wrapt['base64'] = wrapt['text']
wrapt['default'] = wrapt['text']
wrapt['quoted-printable'] = function() wrapt['quoted-printable'] = function()
return ltn12.filter.cycle(qpwrp, 76, 76) return ltn12.filter.cycle(qpwrp, 76, 76)
@ -52,15 +56,7 @@ end
-- function that choose the encoding, decoding or wrap algorithm -- function that choose the encoding, decoding or wrap algorithm
encode = choose(encodet) encode = choose(encodet)
decode = choose(decodet) decode = choose(decodet)
-- it's different because there is a default wrap filter wrap = choose(wrapt)
local cwt = choose(wrapt)
function wrap(mode_or_length, length)
if type(mode_or_length) ~= "string" then
length = mode_or_length
mode_or_length = "text"
end
return cwt(mode_or_length, length)
end
-- define the end-of-line normalization filter -- define the end-of-line normalization filter
function normalize(marker) function normalize(marker)

View File

@ -411,6 +411,7 @@ local body
io.write("testing simple get function: ") io.write("testing simple get function: ")
body = socket.http.get("http://" .. host .. prefix .. "/index.html") body = socket.http.get("http://" .. host .. prefix .. "/index.html")
assert(body == index) assert(body == index)
print("ok")
------------------------------------------------------------------------ ------------------------------------------------------------------------
io.write("testing HEAD method: ") io.write("testing HEAD method: ")
@ -420,6 +421,7 @@ response = socket.http.request {
url = "http://www.cs.princeton.edu/~diego/" url = "http://www.cs.princeton.edu/~diego/"
} }
assert(response and response.headers) assert(response and response.headers)
print("ok")
------------------------------------------------------------------------ ------------------------------------------------------------------------
print("passed all tests") print("passed all tests")