diff --git a/src/http.lua b/src/http.lua index 8f3fdb9..e0c4c27 100644 --- a/src/http.lua +++ b/src/http.lua @@ -143,117 +143,111 @@ local function adjustheaders(headers, host) return lower end +local default = { + host = "", + port = PORT, + path ="/", + scheme = "http" +} + local function adjustrequest(reqt) - -- parse url with default fields - local parsed = url.parse(reqt.url or "", { - host = "", - port = PORT, - path ="/", - scheme = "http" - }) - -- explicit info in reqt overrides that given by the URL - for i,v in reqt do parsed[i] = v end + -- parse url if provided + if reqt.url then + local parsed = url.parse(reqt.url, default) + -- explicit components override url + for i,v in parsed do reqt[i] = reqt[i] or v end + end + socket.try(reqt.host, "invalid host '" .. tostring(reqt.host) .. "'") + socket.try(reqt.path, "invalid path '" .. tostring(reqt.path) .. "'") -- compute uri if user hasn't overriden - parsed.uri = parsed.uri or uri(parsed) + reqt.uri = reqt.uri or uri(reqt) -- adjust headers in request - parsed.headers = adjustheaders(parsed.headers, parsed.host) - return parsed + reqt.headers = adjustheaders(reqt.headers, reqt.host) + return reqt end -local function shouldredirect(reqt, respt) +local function shouldredirect(reqt, code) return (reqt.redirect ~= false) and - (respt.code == 301 or respt.code == 302) and + (code == 301 or code == 302) and (not reqt.method or reqt.method == "GET" or reqt.method == "HEAD") and (not reqt.nredirects or reqt.nredirects < 5) end -local function shouldauthorize(reqt, respt) +local function shouldauthorize(reqt, code) -- if there has been an authorization attempt, it must have failed if reqt.headers and reqt.headers["authorization"] then return nil end -- if last attempt didn't fail due to lack of authentication, -- or we don't have authorization information, we can't retry - return respt.code == 401 and reqt.user and reqt.password + return code == 401 and reqt.user and reqt.password end -local function shouldreceivebody(reqt, respt) +local function shouldreceivebody(reqt, code) if reqt.method == "HEAD" then return nil end - local code = respt.code if code == 204 or code == 304 then return nil end if code >= 100 and code < 200 then return nil end return 1 end -local requestp, authorizep, redirectp +-- forward declarations +local trequest, tauthorize, tredirect -function requestp(reqt) - local reqt = adjustrequest(reqt) - local respt = {} - local con = open(reqt.host, reqt.port) - con:sendrequestline(reqt.method, reqt.uri) - con:sendheaders(reqt.headers) - con:sendbody(reqt.headers, reqt.source, reqt.step) - respt.code, respt.status = con:receivestatusline() - respt.headers = con:receiveheaders() - if shouldredirect(reqt, respt) then - con:close() - return redirectp(reqt, respt) - elseif shouldauthorize(reqt, respt) then - con:close() - return authorizep(reqt, respt) - elseif shouldreceivebody(reqt, respt) then - con:receivebody(respt.headers, reqt.sink, reqt.step) - end - con:close() - return respt -end - -function authorizep(reqt, respt) +function tauthorize(reqt) local auth = "Basic " .. (mime.b64(reqt.user .. ":" .. reqt.password)) reqt.headers["authorization"] = auth - return requestp(reqt) + return trequest(reqt) end -function redirectp(reqt, respt) - -- we create a new table to get rid of anything we don't - -- absolutely need, including authentication info - local redirt = { - method = reqt.method, - -- the RFC says the redirect URL has to be absolute, but some - -- servers do not respect that - url = url.absolute(reqt.url, respt.headers["location"]), +function tredirect(reqt, headers) + -- the RFC says the redirect URL has to be absolute, but some + -- servers do not respect that + return trequest { + url = url.absolute(reqt, headers["location"]), source = reqt.source, sink = reqt.sink, headers = reqt.headers, proxy = reqt.proxy, nredirects = (reqt.nredirects or 0) + 1 } - respt = requestp(redirt) - -- we pass the location header as a clue we redirected - if respt.headers then respt.headers.location = redirt.url end - return respt end -request = socket.protect(requestp) +function trequest(reqt) + reqt = adjustrequest(reqt) + local con = open(reqt.host, reqt.port) + con:sendrequestline(reqt.method, reqt.uri) + con:sendheaders(reqt.headers) + con:sendbody(reqt.headers, reqt.source, reqt.step) + local code, headers, status + code, status = con:receivestatusline() + headers = con:receiveheaders() + if shouldredirect(reqt, code) then + con:close() + return tredirect(reqt, headers) + elseif shouldauthorize(reqt, code) then + con:close() + return tauthorize(reqt) + elseif shouldreceivebody(reqt, code) then + con:receivebody(headers, reqt.sink, reqt.step) + end + con:close() + return 1, code, headers, status +end -get = socket.protect(function(u) +local function srequest(u, body) local t = {} - local respt = requestp { + local reqt = { url = u, sink = ltn12.sink.table(t) } - return (table.getn(t) > 0 or nil) and table.concat(t), respt.headers, - respt.code -end) + if body then + reqt.source = ltn12.source.string(body) + reqt.headers = { ["content-length"] = string.len(body) } + reqt.method = "POST" + end + local code, headers, status = socket.skip(1, trequest(reqt)) + return table.concat(t), code, headers, status +end -post = socket.protect(function(u, body) - local t = {} - local respt = requestp { - url = u, - method = "POST", - source = ltn12.source.string(body), - sink = ltn12.sink.table(t), - headers = { ["content-length"] = string.len(body) } - } - return (table.getn(t) > 0 or nil) and table.concat(t), - respt.headers, respt.code +request = socket.protect(function(reqt, body) + if type(reqt) == "string" then return srequest(reqt, body) + else return trequest(reqt) end end) diff --git a/src/url.lua b/src/url.lua index 960a248..ec26e62 100644 --- a/src/url.lua +++ b/src/url.lua @@ -190,7 +190,7 @@ end -- corresponding absolute url ----------------------------------------------------------------------------- function absolute(base_url, relative_url) - local base = parse(base_url) + local base = type(base_url) == "table" and base_url or parse(base_url) local relative = parse(relative_url) if not base then return relative_url elseif not relative then return base_url diff --git a/test/httptest.lua b/test/httptest.lua index 61dc60a..a171dd9 100644 --- a/test/httptest.lua +++ b/test/httptest.lua @@ -55,12 +55,12 @@ end local check_request = function(request, expect, ignore) local t - if not request.sink then - request.sink, t = ltn12.sink.table(t) - end + if not request.sink then request.sink, t = ltn12.sink.table() end request.source = request.source or (request.body and ltn12.source.string(request.body)) - local response = http.request(request) + local response = {} + response.code, response.headers, response.status = + socket.skip(1, http.request(request)) if t and table.getn(t) > 0 then response.body = table.concat(t) end check_result(response, expect, ignore) end @@ -68,8 +68,8 @@ end ------------------------------------------------------------------------ io.write("testing request uri correctness: ") local forth = cgiprefix .. "/request-uri?" .. "this+is+the+query+string" -local back, h, c, e = http.get("http://" .. host .. forth) -if not back then fail(e) end +local back, c, h = http.request("http://" .. host .. forth) +if not back then fail(c) end back = url.parse(back) if similar(back.query, "this+is+the+query+string") then print("ok") else fail(back.query) end @@ -77,7 +77,7 @@ else fail(back.query) end ------------------------------------------------------------------------ io.write("testing query string correctness: ") forth = "this+is+the+query+string" -back = http.get("http://" .. host .. cgiprefix .. +back = http.request("http://" .. host .. cgiprefix .. "/query-string?" .. forth) if similar(back, forth) then print("ok") else fail("failed!") end @@ -153,7 +153,7 @@ check_request(request, expect, ignore) ------------------------------------------------------------------------ io.write("testing simple post function: ") -back = http.post("http://" .. host .. cgiprefix .. "/cat", index) +back = http.request("http://" .. host .. cgiprefix .. "/cat", index) assert(back == index) ------------------------------------------------------------------------ @@ -378,19 +378,19 @@ check_request(request, expect, ignore) ------------------------------------------------------------------------ local body -io.write("testing simple get function: ") -body = http.get("http://" .. host .. prefix .. "/index.html") +io.write("testing simple request function: ") +body = http.request("http://" .. host .. prefix .. "/index.html") assert(body == index) print("ok") ------------------------------------------------------------------------ io.write("testing HEAD method: ") http.TIMEOUT = 1 -response = http.request { +local r, c, h = http.request { method = "HEAD", url = "http://www.cs.princeton.edu/~diego/" } -assert(response and response.headers) +assert(r and h and c == 200) print("ok") ------------------------------------------------------------------------ @@ -398,7 +398,7 @@ io.write("testing host not found: ") local c, e = socket.connect("wronghost", 80) local r, re = http.request{url = "http://wronghost/does/not/exist"} assert(r == nil and e == re) -r, re = http.get("http://wronghost/does/not/exist") +r, re = http.request("http://wronghost/does/not/exist") assert(r == nil and e == re) print("ok") @@ -407,7 +407,7 @@ io.write("testing invalid url: ") local c, e = socket.connect("", 80) local r, re = http.request{url = host .. prefix} assert(r == nil and e == re) -r, re = http.get(host .. prefix) +r, re = http.request(host .. prefix) assert(r == nil and e == re) print("ok")