Settimeout wasn't returning 1...

http.lua is even cleaner. No trash in respt table.
This commit is contained in:
Diego Nehab 2004-03-22 07:45:07 +00:00
parent 1fa65d89ca
commit 5b279bac9e
3 changed files with 64 additions and 63 deletions

2
TODO
View File

@ -19,6 +19,8 @@
* Separar as classes em arquivos * Separar as classes em arquivos
* Retorno de sendto em datagram sockets pode ser refused * Retorno de sendto em datagram sockets pode ser refused
break smtp.send into c = smtp.open, c:send() c:close()
falar sobre encodet/wrapt/decodet no manual sobre mime falar sobre encodet/wrapt/decodet no manual sobre mime

View File

@ -39,10 +39,10 @@ local function third(a, b, c)
return c return c
end end
local function receive_headers(reqt, respt) local function receive_headers(reqt, respt, tmp)
local headers = {} local sock = tmp.sock
local sock = respt.tmp.sock
local line, name, value, _ local line, name, value, _
local headers = {}
-- store results -- store results
respt.headers = headers respt.headers = headers
-- get first line -- get first line
@ -132,10 +132,10 @@ local function receive_body_untilclosed(sock, sink)
hand(sink, nil) hand(sink, nil)
end end
local function receive_body(reqt, respt) local function receive_body(reqt, respt, tmp)
local sink = reqt.sink or ltn12.sink.null() local sink = reqt.sink or ltn12.sink.null()
local headers = respt.headers local headers = respt.headers
local sock = respt.tmp.sock local sock = tmp.sock
local te = headers["transfer-encoding"] local te = headers["transfer-encoding"]
if te and te ~= "identity" then if te and te ~= "identity" then
-- get by chunked transfer-coding of message body -- get by chunked transfer-coding of message body
@ -174,47 +174,50 @@ local function send_headers(sock, headers)
-- send request headers -- send request headers
for i, v in pairs(headers) do for i, v in pairs(headers) do
socket.try(sock:send(i .. ": " .. v .. "\r\n")) socket.try(sock:send(i .. ": " .. v .. "\r\n"))
--io.write(i .. ": " .. v .. "\r\n")
end end
-- mark end of request headers -- mark end of request headers
socket.try(sock:send("\r\n")) socket.try(sock:send("\r\n"))
--io.write("\r\n")
end end
local function should_receive_body(reqt, respt) local function should_receive_body(reqt, respt, tmp)
if reqt.method == "HEAD" then return nil end if reqt.method == "HEAD" then return nil end
if respt.code == 204 or respt.code == 304 then return nil end if respt.code == 204 or respt.code == 304 then return nil end
if respt.code >= 100 and respt.code < 200 then return nil end if respt.code >= 100 and respt.code < 200 then return nil end
return 1 return 1
end end
local function receive_status(reqt, respt) local function receive_status(reqt, respt, tmp)
local sock = respt.tmp.sock local status = socket.try(tmp.sock:receive())
local status = socket.try(sock:receive())
local code = third(string.find(status, "HTTP/%d*%.%d* (%d%d%d)")) local code = third(string.find(status, "HTTP/%d*%.%d* (%d%d%d)"))
-- store results -- store results
respt.code, respt.status = tonumber(code), status respt.code, respt.status = tonumber(code), status
end end
local function request_uri(reqt, respt) local function request_uri(reqt, respt, tmp)
local url local url = tmp.parsed
local parsed = respt.tmp.parsed
if not reqt.proxy then if not reqt.proxy then
local parsed = tmp.parsed
url = { url = {
path = parsed.path, path = parsed.path,
params = parsed.params, params = parsed.params,
query = parsed.query, query = parsed.query,
fragment = parsed.fragment fragment = parsed.fragment
} }
else url = respt.tmp.parsed end end
return socket.url.build(url) return socket.url.build(url)
end end
local function send_request(reqt, respt) local function send_request(reqt, respt, tmp)
local uri = request_uri(reqt, respt) local uri = request_uri(reqt, respt, tmp)
local sock = respt.tmp.sock local sock = tmp.sock
local headers = respt.tmp.headers local headers = tmp.headers
-- send request line -- send request line
socket.try(sock:send((reqt.method or "GET") socket.try(sock:send((reqt.method or "GET")
.. " " .. uri .. " HTTP/1.1\r\n")) .. " " .. uri .. " HTTP/1.1\r\n"))
--io.write((reqt.method or "GET")
--.. " " .. uri .. " HTTP/1.1\r\n")
-- send request headers headeres -- send request headers headeres
if reqt.source and not headers["content-length"] then if reqt.source and not headers["content-length"] then
headers["transfer-encoding"] = "chunked" headers["transfer-encoding"] = "chunked"
@ -227,8 +230,7 @@ local function send_request(reqt, respt)
end end
end end
local function open(reqt, respt) local function open(reqt, respt, tmp)
local parsed = respt.tmp.parsed
local proxy = reqt.proxy or PROXY local proxy = reqt.proxy or PROXY
local host, port local host, port
if proxy then if proxy then
@ -236,16 +238,15 @@ local function open(reqt, respt)
assert(pproxy.port and pproxy.host, "invalid proxy") assert(pproxy.port and pproxy.host, "invalid proxy")
host, port = pproxy.host, pproxy.port host, port = pproxy.host, pproxy.port
else else
host, port = parsed.host, parsed.port host, port = tmp.parsed.host, tmp.parsed.port
end end
local sock = socket.try(socket.tcp())
-- store results -- store results
respt.tmp.sock = sock tmp.sock = socket.try(socket.tcp())
sock:settimeout(reqt.timeout or TIMEOUT) socket.try(tmp.sock:settimeout(reqt.timeout or TIMEOUT))
socket.try(sock:connect(host, port)) socket.try(tmp.sock:connect(host, port))
end end
local function adjust_headers(reqt, respt) local function adjust_headers(reqt, respt, tmp)
local lower = {} local lower = {}
local headers = reqt.headers or {} local headers = reqt.headers or {}
-- set default headers -- set default headers
@ -254,14 +255,14 @@ local function adjust_headers(reqt, respt)
for i,v in headers do for i,v in headers do
lower[string.lower(i)] = v lower[string.lower(i)] = v
end end
lower["host"] = respt.tmp.parsed.host lower["host"] = tmp.parsed.host
-- this cannot be overriden -- this cannot be overriden
lower["connection"] = "close" lower["connection"] = "close"
-- store results -- store results
respt.tmp.headers = lower tmp.headers = lower
end end
local function parse_url(reqt, respt) local function parse_url(reqt, respt, tmp)
-- parse url with default fields -- parse url with default fields
local parsed = socket.url.parse(reqt.url, { local parsed = socket.url.parse(reqt.url, {
host = "", host = "",
@ -277,19 +278,18 @@ local function parse_url(reqt, respt)
parsed.user = reqt.user or parsed.user parsed.user = reqt.user or parsed.user
parsed.password = reqt.password or parsed.password parsed.password = reqt.password or parsed.password
-- store results -- store results
respt.tmp.parsed = parsed tmp.parsed = parsed
end end
-- forward declaration -- forward declaration
local request_p local request_p
local function should_authorize(reqt, respt) local function should_authorize(reqt, respt, tmp)
-- if there has been an authorization attempt, it must have failed -- if there has been an authorization attempt, it must have failed
if reqt.headers and reqt.headers["authorization"] then return nil end if reqt.headers and reqt.headers["authorization"] then return nil end
-- if last attempt didn't fail due to lack of authentication, -- if last attempt didn't fail due to lack of authentication,
-- or we don't have authorization information, we can't retry -- or we don't have authorization information, we can't retry
return respt.code == 401 and return respt.code == 401 and tmp.parsed.user and tmp.parsed.password
respt.tmp.parsed.user and respt.tmp.parsed.password
end end
local function clone(headers) local function clone(headers)
@ -301,11 +301,10 @@ local function clone(headers)
return copy return copy
end end
local function authorize(reqt, respt) local function authorize(reqt, respt, tmp)
local headers = clone(reqt.headers) or {} local headers = clone(reqt.headers) or {}
local parsed = respt.tmp.parsed
headers["authorization"] = "Basic " .. headers["authorization"] = "Basic " ..
(mime.b64(parsed.user .. ":" .. parsed.password)) (mime.b64(tmp.parsed.user .. ":" .. tmp.parsed.password))
local autht = { local autht = {
method = reqt.method, method = reqt.method,
url = reqt.url, url = reqt.url,
@ -315,18 +314,18 @@ local function authorize(reqt, respt)
timeout = reqt.timeout, timeout = reqt.timeout,
proxy = reqt.proxy, proxy = reqt.proxy,
} }
request_p(autht, respt) request_p(autht, respt, tmp)
end end
local function should_redirect(reqt, respt) local function should_redirect(reqt, respt, tmp)
return (reqt.redirect ~= false) and return (reqt.redirect ~= false) and
(respt.code == 301 or respt.code == 302) and (respt.code == 301 or respt.code == 302) and
(not reqt.method or reqt.method == "GET" or reqt.method == "HEAD") (not reqt.method or reqt.method == "GET" or reqt.method == "HEAD")
and (not respt.tmp.nredirects or respt.tmp.nredirects < 5) and (not tmp.nredirects or tmp.nredirects < 5)
end end
local function redirect(reqt, respt) local function redirect(reqt, respt, tmp)
respt.tmp.nredirects = (respt.tmp.nredirects or 0) + 1 tmp.nredirects = (tmp.nredirects or 0) + 1
local redirt = { local redirt = {
method = reqt.method, method = reqt.method,
-- the RFC says the redirect URL has to be absolute, but some -- the RFC says the redirect URL has to be absolute, but some
@ -338,36 +337,35 @@ local function redirect(reqt, respt)
timeout = reqt.timeout, timeout = reqt.timeout,
proxy = reqt.proxy proxy = reqt.proxy
} }
request_p(redirt, respt) request_p(redirt, respt, tmp)
-- we pass the location header as a clue we redirected -- we pass the location header as a clue we redirected
if respt.headers then respt.headers.location = redirt.url end if respt.headers then respt.headers.location = redirt.url end
end end
-- execute a request of through an exception -- execute a request of through an exception
function request_p(reqt, respt) function request_p(reqt, respt, tmp)
parse_url(reqt, respt) parse_url(reqt, respt, tmp)
adjust_headers(reqt, respt) adjust_headers(reqt, respt, tmp)
open(reqt, respt) open(reqt, respt, tmp)
send_request(reqt, respt) send_request(reqt, respt, tmp)
receive_status(reqt, respt) receive_status(reqt, respt, tmp)
receive_headers(reqt, respt) receive_headers(reqt, respt, tmp)
if should_redirect(reqt, respt) then if should_redirect(reqt, respt, tmp) then
respt.tmp.sock:close() tmp.sock:close()
redirect(reqt, respt) redirect(reqt, respt, tmp)
elseif should_authorize(reqt, respt) then elseif should_authorize(reqt, respt, tmp) then
respt.tmp.sock:close() tmp.sock:close()
authorize(reqt, respt) authorize(reqt, respt, tmp)
elseif should_receive_body(reqt, respt) then elseif should_receive_body(reqt, respt, tmp) then
receive_body(reqt, respt) receive_body(reqt, respt, tmp)
end end
end end
function request(reqt) function request(reqt)
local respt = { tmp = {} } local respt, tmp = {}, {}
local s, e = pcall(request_p, reqt, respt) local s, e = pcall(request_p, reqt, respt, tmp)
if not s then respt.error = e end if not s then respt.error = e end
if respt.tmp.sock then respt.tmp.sock:close() end if tmp.sock then tmp.sock:close() end
respt.tmp = nil
return respt return respt
end end
@ -377,7 +375,7 @@ function get(url)
url = url, url = url,
sink = ltn12.sink.table(t) sink = ltn12.sink.table(t)
} }
return table.getn(t) > 0 and table.concat(t), respt.headers, return (table.getn(t) > 0 or nil) and table.concat(t), respt.headers,
respt.code, respt.error respt.code, respt.error
end end
@ -390,6 +388,6 @@ function post(url, body)
sink = ltn12.sink.table(t), sink = ltn12.sink.table(t),
headers = { ["content-length"] = string.len(body) } headers = { ["content-length"] = string.len(body) }
} }
return table.getn(t) > 0 and table.concat(t), return (table.getn(t) > 0 or nil) and table.concat(t),
respt.headers, respt.code, respt.error respt.headers, respt.code, respt.error
end end

View File

@ -169,7 +169,8 @@ int tm_meth_settimeout(lua_State *L, p_tm tm)
luaL_argcheck(L, 0, 3, "invalid timeout mode"); luaL_argcheck(L, 0, 3, "invalid timeout mode");
break; break;
} }
return 0; lua_pushnumber(L, 1);
return 1;
} }
/*=========================================================================*\ /*=========================================================================*\