Using socket pumps in http.lua.

Adjusted socket.try.
This commit is contained in:
Diego Nehab 2004-03-26 06:05:20 +00:00
parent e77f179200
commit e5a090b01c
5 changed files with 41 additions and 124 deletions

1
TODO
View File

@ -27,6 +27,7 @@ falar sobre encodet/wrapt/decodet no manual sobre mime
RECEIVE MUDOU!!! COLOCAR NO MANUAL.
HTTP.lua mudou bastante também.
fazer com que a socket.source e socket.sink sejam "selectable".
change mime.eol to output marker on detection of first candidate, instead
of on the second. that way it works in one pass for strings that end with

View File

@ -9,4 +9,4 @@ else
convert = ltn12.filter.chain(base64, wrap)
end
source = ltn12.source.chain(source, convert)
ltn12.pump(source, sink)
repeat until not ltn12.pump(source, sink)

View File

@ -51,7 +51,7 @@ local function receive_headers(reqt, respt, tmp)
while line ~= "" do
-- get field-name and value
_, _, name, value = string.find(line, "^(.-):%s*(.*)")
assert(name and value, "malformed reponse headers")
socket.try(name and value, "malformed reponse headers")
name = string.lower(name)
-- get next line (value might be folded)
line = socket.try(sock:receive())
@ -66,119 +66,32 @@ local function receive_headers(reqt, respt, tmp)
end
end
local function abort(cb, err)
local go, cb_err = cb(nil, err)
error(cb_err or err)
end
local function hand(cb, chunk)
local go, cb_err = cb(chunk)
assert(go, cb_err or "aborted by callback")
end
local function receive_body_bychunks(sock, sink)
while 1 do
-- get chunk size, skip extention
local line, err = sock:receive()
if err then abort(sink, err) end
local size = tonumber(string.gsub(line, ";.*", ""), 16)
if not size then abort(sink, "invalid chunk size") end
-- was it the last chunk?
if size <= 0 then break end
-- get chunk
local chunk, err = sock:receive(size)
if err then abort(sink, err) end
-- pass chunk to callback
hand(sink, chunk)
-- skip CRLF on end of chunk
err = second(sock:receive())
if err then abort(sink, err) end
end
-- let callback know we are done
hand(sink, nil)
-- servers shouldn't send trailer headers, but who trusts them?
local line = socket.try(sock:receive())
while line ~= "" do
line = socket.try(sock:receive())
end
end
local function receive_body_bylength(sock, length, sink)
while length > 0 do
local size = math.min(BLOCKSIZE, length)
local chunk, err = sock:receive(size)
if err then abort(sink, err) end
length = length - string.len(chunk)
-- see if there was an error
hand(sink, chunk)
end
-- let callback know we are done
hand(sink, nil)
end
local function receive_body_untilclosed(sock, sink)
while true do
local chunk, err, partial = sock:receive(BLOCKSIZE)
-- see if we are done
if err == "closed" then
hand(sink, partial)
break
end
hand(sink, chunk)
-- see if there was an error
if err then abort(sink, err) end
end
-- let callback know we are done
hand(sink, nil)
end
local function receive_body(reqt, respt, tmp)
local sink = reqt.sink or ltn12.sink.null()
local headers = respt.headers
local sock = tmp.sock
local te = headers["transfer-encoding"]
local pump = reqt.pump or ltn12.pump
local source
local te = respt.headers["transfer-encoding"]
if te and te ~= "identity" then
-- get by chunked transfer-coding of message body
receive_body_bychunks(sock, sink)
elseif tonumber(headers["content-length"]) then
source = socket.source("http-chunked", tmp.sock)
elseif tonumber(respt.headers["content-length"]) then
-- get by content-length
local length = tonumber(headers["content-length"])
receive_body_bylength(sock, length, sink)
local length = tonumber(respt.headers["content-length"])
source = socket.source("by-length", tmp.sock, length)
else
-- get it all until connection closes
receive_body_untilclosed(sock, sink)
end
end
local function send_body_bychunks(data, source)
while true do
local chunk, err = source()
assert(chunk or not err, err)
if not chunk then break end
socket.try(data:send(string.format("%X\r\n", string.len(chunk))))
socket.try(data:send(chunk, "\r\n"))
end
socket.try(data:send("0\r\n\r\n"))
end
local function send_body(data, source)
while true do
local chunk, err = source()
assert(chunk or not err, err)
if not chunk then break end
socket.try(data:send(chunk))
source = socket.source("until-closed", tmp.sock)
end
socket.try(pump(source, sink))
end
local function send_headers(sock, headers)
-- send request headers
for i, v in pairs(headers) do
socket.try(sock:send(i .. ": " .. v .. "\r\n"))
--io.write(i .. ": " .. v .. "\r\n")
end
-- mark end of request headers
socket.try(sock:send("\r\n"))
--io.write("\r\n")
end
local function should_receive_body(reqt, respt, tmp)
@ -211,22 +124,21 @@ end
local function send_request(reqt, respt, tmp)
local uri = request_uri(reqt, respt, tmp)
local sock = tmp.sock
local headers = tmp.headers
local pump = reqt.pump or ltn12.pump
-- send request line
socket.try(sock:send((reqt.method or "GET")
socket.try(tmp.sock:send((reqt.method or "GET")
.. " " .. uri .. " HTTP/1.1\r\n"))
--io.write((reqt.method or "GET")
--.. " " .. uri .. " HTTP/1.1\r\n")
-- send request headers headeres
if reqt.source and not headers["content-length"] then
headers["transfer-encoding"] = "chunked"
end
send_headers(sock, headers)
send_headers(tmp.sock, headers)
-- send request message body, if any
if reqt.source then
if headers["content-length"] then send_body(sock, reqt.source)
else send_body_bychunks(sock, reqt.source) end
if not reqt.source then return end
if headers["content-length"] then
socket.try(pump(reqt.source, socket.sink(tmp.sock)))
else
socket.try(pump(reqt.source, socket.sink("http-chunked", tmp.sock)))
end
end
@ -235,7 +147,7 @@ local function open(reqt, respt, tmp)
local host, port
if proxy then
local pproxy = socket.url.parse(proxy)
assert(pproxy.port and pproxy.host, "invalid proxy")
socket.try(pproxy.port and pproxy.host, "invalid proxy")
host, port = pproxy.host, pproxy.port
else
host, port = tmp.parsed.host, tmp.parsed.port
@ -271,9 +183,8 @@ local function parse_url(reqt, respt, tmp)
scheme = "http"
})
-- scheme has to be http
if parsed.scheme ~= "http" then
error(string.format("unknown scheme '%s'", parsed.scheme))
end
socket.try(parsed.scheme == "http",
string.format("unknown scheme '%s'", parsed.scheme))
-- explicit authentication info overrides that given by the URL
parsed.user = reqt.user or parsed.user
parsed.password = reqt.password or parsed.password
@ -342,6 +253,12 @@ local function redirect(reqt, respt, tmp)
if respt.headers then respt.headers.location = redirt.url end
end
local function skip_continue(reqt, respt, tmp)
if respt.code == 100 then
receive_status(reqt, respt, tmp)
end
end
-- execute a request of through an exception
function request_p(reqt, respt, tmp)
parse_url(reqt, respt, tmp)
@ -349,6 +266,7 @@ function request_p(reqt, respt, tmp)
open(reqt, respt, tmp)
send_request(reqt, respt, tmp)
receive_status(reqt, respt, tmp)
skip_continue(reqt, respt, tmp)
receive_headers(reqt, respt, tmp)
if should_redirect(reqt, respt, tmp) then
tmp.sock:close()

View File

@ -11,11 +11,14 @@ decodet = {}
wrapt = {}
-- creates a function that chooses a filter by name from a given table
local function choose(table)
return function(name, opt)
function choose(table)
return function(name, opt1, opt2)
if type(name) ~= "string" then
name, opt1, opt2 = "default", name, opt1
end
local f = table[name or "nil"]
if not f then error("unknown filter (" .. tostring(name) .. ")", 3)
else return f(opt) end
if not f then error("unknown key (" .. tostring(name) .. ")", 3)
else return f(opt1, opt2) end
end
end
@ -44,6 +47,7 @@ wrapt['text'] = function(length)
return ltn12.filter.cycle(wrp, length, length)
end
wrapt['base64'] = wrapt['text']
wrapt['default'] = wrapt['text']
wrapt['quoted-printable'] = function()
return ltn12.filter.cycle(qpwrp, 76, 76)
@ -52,15 +56,7 @@ end
-- function that choose the encoding, decoding or wrap algorithm
encode = choose(encodet)
decode = choose(decodet)
-- it's different because there is a default wrap filter
local cwt = choose(wrapt)
function wrap(mode_or_length, length)
if type(mode_or_length) ~= "string" then
length = mode_or_length
mode_or_length = "text"
end
return cwt(mode_or_length, length)
end
wrap = choose(wrapt)
-- define the end-of-line normalization filter
function normalize(marker)

View File

@ -411,6 +411,7 @@ local body
io.write("testing simple get function: ")
body = socket.http.get("http://" .. host .. prefix .. "/index.html")
assert(body == index)
print("ok")
------------------------------------------------------------------------
io.write("testing HEAD method: ")
@ -420,6 +421,7 @@ response = socket.http.request {
url = "http://www.cs.princeton.edu/~diego/"
}
assert(response and response.headers)
print("ok")
------------------------------------------------------------------------
print("passed all tests")