diff --git a/src/http.lua b/src/http.lua index 19b4dd3..2fa5a26 100644 --- a/src/http.lua +++ b/src/http.lua @@ -26,15 +26,24 @@ _M.TIMEOUT = 60 -- user agent field sent in request _M.USERAGENT = socket._VERSION --- supported schemes +-- supported schemes and their particulars local SCHEMES = { - http = { port = 80 } - , https = { port = 443 }} + http = { + port = 80 + , create = function(t) + return socket.tcp end } + , https = { + port = 443 + , create = function(t) + local https = assert( + require("ssl.https"), 'LuaSocket: LuaSec not found') + local tcp = assert( + https.tcp, 'LuaSocket: Function tcp() not available from LuaSec') + return tcp(t) end }} -- default scheme and port for document retrieval local SCHEME = 'http' local PORT = SCHEMES[SCHEME].port - ----------------------------------------------------------------------------- -- Reads MIME headers from a connection, unfolding where needed ----------------------------------------------------------------------------- @@ -115,13 +124,13 @@ local metat = { __index = {} } function _M.open(host, port, create) -- create socket with user connect function, or with default - local c = socket.try((create or socket.tcp)()) + local c = socket.try(create()) local h = base.setmetatable({ c = c }, metat) -- create finalized try h.try = socket.newtry(function() h:close() end) -- set timeout before connecting h.try(c:settimeout(_M.TIMEOUT)) - h.try(c:connect(host, port or PORT)) + h.try(c:connect(host, port)) -- here everything worked return h end @@ -221,14 +230,13 @@ end local function adjustheaders(reqt) -- default headers - local headhost = reqt.host - local headport = tostring(reqt.port) - local schemeport = tostring(SCHEMES[reqt.scheme].port) - if headport ~= schemeport then - headhost = headhost .. ':' .. headport end + local host = reqt.host + local port = tostring(reqt.port) + if port ~= tostring(SCHEMES[reqt.scheme].port) then + host = host .. ':' .. port end local lower = { ["user-agent"] = _M.USERAGENT, - ["host"] = headhost, + ["host"] = host, ["connection"] = "close, TE", ["te"] = "trailers" } @@ -267,8 +275,13 @@ local function adjustrequest(reqt) local nreqt = reqt.url and url.parse(reqt.url, default) or {} -- explicit components override url for i,v in base.pairs(reqt) do nreqt[i] = v end - if nreqt.port == "" then nreqt.port = PORT end - if not (nreqt.host and nreqt.host ~= "") then + -- default to scheme particulars + local schemedefs, host, port, method + = SCHEMES[nreqt.scheme], nreqt.host, nreqt.port, nreqt.method + if not nreqt.create then nreqt.create = schemedefs.create(nreqt) end + if not (port and port ~= '') then nreqt.port = schemedefs.port end + if not (method and method ~= '') then nreqt.method = 'GET' end + if not (host and host ~= "") then socket.try(nil, "invalid host '" .. base.tostring(nreqt.host) .. "'") end -- compute uri if user hasn't overriden @@ -285,8 +298,10 @@ local function shouldredirect(reqt, code, headers) if not location then return false end location = string.gsub(location, "%s", "") if location == "" then return false end - local scheme = string.match(location, "^([%w][%w%+%-%.]*)%:") - if scheme and not SCHEMES[scheme] then return false end + local scheme = url.parse(location).scheme + if scheme and (not SCHEMES[scheme]) then return false end + -- avoid https downgrades + if ('https' == reqt.scheme) and ('https' ~= scheme) then return false end return (reqt.redirect ~= false) and (code == 301 or code == 302 or code == 303 or code == 307) and (not reqt.method or reqt.method == "GET" or reqt.method == "HEAD") @@ -306,10 +321,16 @@ end local trequest, tredirect --[[local]] function tredirect(reqt, location) + -- the RFC says the redirect URL has to be absolute, but some + -- servers do not respect that + local newurl = url.absolute(reqt.url, location) + -- if switching schemes, reset port and create function + if url.parse(newurl).scheme ~= reqt.scheme then + reqt.port = nil + reqt.create = nil end + -- make new request local result, code, headers, status = trequest { - -- the RFC says the redirect URL has to be absolute, but some - -- servers do not respect that - url = url.absolute(reqt.url, location), + url = newurl, source = reqt.source, sink = reqt.sink, headers = reqt.headers, @@ -397,4 +418,5 @@ _M.request = socket.protect(function(reqt, body) else return trequest(reqt) end end) +_M.schemes = SCHEMES return _M