diff --git a/TODO b/TODO index 9fdb1e8..bd8a950 100644 --- a/TODO +++ b/TODO @@ -1,7 +1,17 @@ -replace times by getrusage +change send/recv to avoid using select + +add gethostname and use it in HTTP, SMTP etc, and add manual entry. +add local connect, and manual entry +add shutdown, and manual entry + +only allocate in case of success +only call select if io fails... +Proxy support pro http + +make REUSEADDR an option... make sure modules know if their dependencies are there. - +_ one thing i noticed in usocket.c is that it doesn't check for EINTR after write(), sendto(), read(), recvfrom() etc. ? the usual trick is to loop while you get EINTR: @@ -68,7 +78,6 @@ Ajeitar o protocolo da luaopen_socket()... sei l - proteger ou atomizar o conjunto (timedout, receive), (timedout, send) - inet_ntoa também é uma merda. - SSL -- Proxy support pro http - checar operações em closed sockets - checar teste de writable socket com select diff --git a/etc/get.lua b/etc/get.lua index a093e24..2d804a0 100644 --- a/etc/get.lua +++ b/etc/get.lua @@ -99,7 +99,7 @@ end function getbyhttp(url, file, size) local response = socket.http.request_cb( {url = url}, - {body_cb = receive2disk(file, size)} + {body_cb = receive2disk(file, size)} ) print() if response.code ~= 200 then print(response.status or response.error) end diff --git a/src/http.lua b/src/http.lua index 18a44b6..1925c68 100644 --- a/src/http.lua +++ b/src/http.lua @@ -5,22 +5,29 @@ -- Conforming to: RFC 2616, LTN7 -- RCS ID: $Id$ ----------------------------------------------------------------------------- - -local Public, Private = {}, {} -local socket = _G[LUASOCKET_LIBNAME] -- get LuaSocket namespace -socket.http = Public -- create http sub namespace +-- make sure LuaSocket is loaded +if not LUASOCKET_LIBNAME then error('module requires LuaSocket') end +-- get LuaSocket namespace +local socket = _G[LUASOCKET_LIBNAME] +if not socket then error('module requires LuaSocket') end +-- create smtp namespace inside LuaSocket namespace +local http = {} +socket.http = http +-- make all module globals fall into smtp namespace +setmetatable(http, { __index = _G }) +setfenv(1, http) ----------------------------------------------------------------------------- -- Program constants ----------------------------------------------------------------------------- -- connection timeout in seconds -Public.TIMEOUT = 60 +TIMEOUT = 60 -- default port for document retrieval -Public.PORT = 80 +PORT = 80 -- user agent field sent in request -Public.USERAGENT = "LuaSocket 2.0" +USERAGENT = "LuaSocket 2.0" -- block size used in transfers -Public.BLOCKSIZE = 8192 +BLOCKSIZE = 8192 ----------------------------------------------------------------------------- -- Tries to get a pattern from the server and closes socket on error @@ -30,7 +37,7 @@ Public.BLOCKSIZE = 8192 -- received pattern on success -- nil followed by error message on error ----------------------------------------------------------------------------- -function Private.try_receive(sock, pattern) +local function try_receiving(sock, pattern) local data, err = sock:receive(pattern) if not data then sock:close() end return data, err @@ -43,25 +50,12 @@ end -- Returns -- err: error message if any, nil if successfull ----------------------------------------------------------------------------- -function Private.try_send(sock, data) - local sent, err = sock:send(data) +local function try_sending(sock, ...) + local sent, err = sock:send(unpack(arg)) if not sent then sock:close() end return err end ------------------------------------------------------------------------------ --- Computes status code from HTTP status line --- Input --- line: HTTP status line --- Returns --- code: integer with status code, or nil if malformed line ------------------------------------------------------------------------------ -function Private.get_statuscode(line) - local code, _ - _, _, code = string.find(line, "HTTP/%d*%.%d* (%d%d%d)") - return tonumber(code) -end - ----------------------------------------------------------------------------- -- Receive server reply messages, parsing for status code -- Input @@ -71,10 +65,13 @@ end -- line: full HTTP status line -- err: error message if any ----------------------------------------------------------------------------- -function Private.receive_status(sock) +local function receive_status(sock) local line, err - line, err = Private.try_receive(sock) - if not err then return Private.get_statuscode(line), line + line, err = try_receiving(sock) + if not err then + local code, _ + _, _, code = string.find(line, "HTTP/%d*%.%d* (%d%d%d)") + return tonumber(code), line else return nil, nil, err end end @@ -89,11 +86,12 @@ end -- all name_i are lowercase -- nil and error message in case of error ----------------------------------------------------------------------------- -function Private.receive_headers(sock, headers) +local function receive_headers(sock, headers) local line, err local name, value, _ + headers = headers or {} -- get first line - line, err = Private.try_receive(sock) + line, err = try_receiving(sock) if err then return nil, err end -- headers go until a blank line is found while line ~= "" do @@ -105,12 +103,12 @@ function Private.receive_headers(sock, headers) end name = string.lower(name) -- get next line (value might be folded) - line, err = Private.try_receive(sock) + line, err = try_receiving(sock) if err then return nil, err end -- unfold any folded values while not err and string.find(line, "^%s") do value = value .. line - line, err = Private.try_receive(sock) + line, err = try_receiving(sock) if err then return nil, err end end -- save pair in table @@ -120,6 +118,19 @@ function Private.receive_headers(sock, headers) return headers end +----------------------------------------------------------------------------- +-- Aborts a receive callback +-- Input +-- cb: callback function +-- err: error message to pass to callback +-- Returns +-- callback return or if nil err +----------------------------------------------------------------------------- +local function abort(cb, err) + local go, err_or_f = cb(nil, err) + return err_or_f or err +end + ----------------------------------------------------------------------------- -- Receives a chunked message body -- Input @@ -129,54 +140,37 @@ end -- Returns -- nil if successfull or an error message in case of error ----------------------------------------------------------------------------- -function Private.receivebody_bychunks(sock, headers, receive_cb) - local chunk, size, line, err, go, uerr, _ +local function receive_body_bychunks(sock, headers, receive_cb) + local chunk, size, line, err, go, err_or_f, _ while 1 do -- get chunk size, skip extention - line, err = Private.try_receive(sock) - if err then - local go, uerr = receive_cb(nil, err) - return uerr or err - end + line, err = try_receiving(sock) + if err then return abort(receive_cb, err) end size = tonumber(string.gsub(line, ";.*", ""), 16) - if not size then - err = "invalid chunk size" - sock:close() - go, uerr = receive_cb(nil, err) - return uerr or err - end + if not size then return abort(receive_cb, "invalid chunk size") end -- was it the last chunk? if size <= 0 then break end -- get chunk - chunk, err = Private.try_receive(sock, size) - if err then - go, uerr = receive_cb(nil, err) - return uerr or err - end + chunk, err = try_receiving(sock, size) + if err then return abort(receive_cb, err) end -- pass chunk to callback - go, uerr = receive_cb(chunk) - if not go then - sock:close() - return uerr or "aborted by callback" - end + go, err_or_f = receive_cb(chunk) + -- see if callback needs to be replaced + receive_cb = err_or_f or receive_cb + -- see if callback aborted + if not go then return err_or_f or "aborted by callback" end -- skip CRLF on end of chunk - _, err = Private.try_receive(sock) - if err then - go, uerr = receive_cb(nil, err) - return uerr or err - end + _, err = try_receiving(sock) + if err then return abort(receive_cb, err) end end -- the server should not send trailer headers because we didn't send a -- header informing it we know how to deal with them. we do not risk -- being caught unprepaired. - headers, err = Private.receive_headers(sock, headers) - if err then - go, uerr = receive_cb(nil, err) - return uerr or err - end + _, err = receive_headers(sock, headers) + if err then return abort(receive_cb, err) end -- let callback know we are done - go, uerr = receive_cb("") - return uerr + _, err_or_f = receive_cb("") + return err_or_f end ----------------------------------------------------------------------------- @@ -188,25 +182,21 @@ end -- Returns -- nil if successfull or an error message in case of error ----------------------------------------------------------------------------- -function Private.receivebody_bylength(sock, length, receive_cb) - local uerr, go +local function receive_body_bylength(sock, length, receive_cb) while length > 0 do - local size = math.min(Public.BLOCKSIZE, length) + local size = math.min(BLOCKSIZE, length) local chunk, err = sock:receive(size) - -- if there was an error before we got all the data - if err and string.len(chunk) ~= length then - go, uerr = receive_cb(nil, err) - return uerr or err - end - go, uerr = receive_cb(chunk) - if not go then - sock:close() - return uerr or "aborted by callback" - end - length = length - size + local go, err_or_f = receive_cb(chunk) + length = length - string.len(chunk) + -- see if callback aborted + if not go then return err_or_f or "aborted by callback" end + -- see if callback needs to be replaced + receive_cb = err_or_f or receive_cb + -- see if there was an error + if err and length > 0 then return abort(receive_cb, err) end end - go, uerr = receive_cb("") - return uerr + local _, err_or_f = receive_cb("") + return err_or_f end ----------------------------------------------------------------------------- @@ -217,24 +207,24 @@ end -- Returns -- nil if successfull or an error message in case of error ----------------------------------------------------------------------------- -function Private.receivebody_untilclosed(sock, receive_cb) - local err, go, uerr +local function receive_body_untilclosed(sock, receive_cb) while 1 do - local chunk, err = sock:receive(Public.BLOCKSIZE) - if err == "closed" or not err then - go, uerr = receive_cb(chunk) - if not go then - sock:close() - return uerr or "aborted by callback" + local chunk, err = sock:receive(BLOCKSIZE) + local go, err_or_f = receive_cb(chunk) + -- see if callback aborted + if not go then return err_or_f or "aborted by callback" end + -- see if callback needs to be replaced + receive_cb = err_or_f or receive_cb + -- see if we are done + if err == "closed" then + if chunk ~= "" then + go, err_or_f = receive_cb("") + return err_or_f end - if err == "closed" then break end - else - go, uerr = callback(nil, err) - return uerr or err end + -- see if there was an error + if err then return abort(receive_cb, err) end end - go, uerr = receive_cb("") - return uerr end ----------------------------------------------------------------------------- @@ -246,59 +236,68 @@ end -- Returns -- nil if successfull or an error message in case of error ----------------------------------------------------------------------------- -function Private.receive_body(sock, headers, receive_cb) +local function receive_body(sock, headers, receive_cb) local te = headers["transfer-encoding"] if te and te ~= "identity" then -- get by chunked transfer-coding of message body - return Private.receivebody_bychunks(sock, headers, receive_cb) + return receive_body_bychunks(sock, headers, receive_cb) elseif tonumber(headers["content-length"]) then -- get by content-length local length = tonumber(headers["content-length"]) - return Private.receivebody_bylength(sock, length, receive_cb) + return receive_body_bylength(sock, length, receive_cb) else -- get it all until connection closes - return Private.receivebody_untilclosed(sock, receive_cb) + return receive_body_untilclosed(sock, receive_cb) end end ------------------------------------------------------------------------------ --- Drop HTTP response body --- Input --- sock: socket connected to the server --- headers: response header fields --- Returns --- nil if successfull or an error message in case of error ------------------------------------------------------------------------------ -function Private.drop_body(sock, headers) - return Private.receive_body(sock, headers, function (c, e) return 1 end) -end - ----------------------------------------------------------------------------- -- Sends data comming from a callback -- Input -- data: data connection -- send_cb: callback to produce file contents --- chunk, size: first callback return values -- Returns -- nil if successfull, or an error message in case of error ----------------------------------------------------------------------------- -function Private.send_indirect(data, send_cb, chunk, size) - local total, sent, err - total = 0 +local function send_body_bychunks(data, send_cb) while 1 do - if type(chunk) ~= "string" or type(size) ~= "number" then - data:close() - if not chunk and type(size) == "string" then return size - else return "invalid callback return" end - end - sent, err = data:send(chunk) - if err then - data:close() - return err - end - total = total + sent - if total >= size then break end - chunk, size = send_cb() + local chunk, err_or_f = send_cb() + -- check if callback aborted + if not chunk then return err_or_f or "aborted by callback" end + -- check if callback should be replaced + send_cb = err_or_f or send_cb + -- if we are done, send last-chunk + if chunk == "" then return try_sending(data, "0\r\n\r\n") end + -- else send middle chunk + local err = try_sending(data, + string.format("%X\r\n", string.len(chunk)), + chunk, + "\r\n" + ) + if err then return err end + end +end + +----------------------------------------------------------------------------- +-- Sends data comming from a callback +-- Input +-- data: data connection +-- send_cb: callback to produce body contents +-- Returns +-- nil if successfull, or an error message in case of error +----------------------------------------------------------------------------- +local function send_body_bylength(data, send_cb) + while 1 do + local chunk, err_or_f = send_cb() + -- check if callback aborted + if not chunk then return err_or_f or "aborted by callback" end + -- check if callback should be replaced + send_cb = err_or_f or send_cb + -- check if callback is done + if chunk == "" then return end + -- send data + local err = try_sending(data, chunk) + if err then return err end end end @@ -310,16 +309,16 @@ end -- Returns -- err: error message if any ----------------------------------------------------------------------------- -function Private.send_headers(sock, headers) +local function send_headers(sock, headers) local err headers = headers or {} -- send request headers for i, v in headers do - err = Private.try_send(sock, i .. ": " .. v .. "\r\n") + err = try_sending(sock, i .. ": " .. v .. "\r\n") if err then return err end end -- mark end of request headers - return Private.try_send(sock, "\r\n") + return try_sending(sock, "\r\n") end ----------------------------------------------------------------------------- @@ -333,43 +332,39 @@ end -- Returns -- err: nil in case of success, error message otherwise ----------------------------------------------------------------------------- -function Private.send_request(sock, method, uri, headers, body_cb) +local function send_request(sock, method, uri, headers, body_cb) local chunk, size, done, err -- send request line - err = Private.try_send(sock, method .. " " .. uri .. " HTTP/1.1\r\n") + err = try_sending(sock, method .. " " .. uri .. " HTTP/1.1\r\n") if err then return err end - -- if there is a request message body, add content-length header - chunk, size = body_cb() - if type(chunk) == "string" and type(size) == "number" then - if size > 0 then - headers["content-length"] = tostring(size) - end - else - sock:close() - if not chunk and type(size) == "string" then return size - else return "invalid callback return" end + if body_cb and not headers["content-length"] then + headers["transfer-encoding"] = "chunked" end -- send request headers - err = Private.send_headers(sock, headers) + err = send_headers(sock, headers) if err then return err end -- send request message body, if any if body_cb then - return Private.send_indirect(sock, body_cb, chunk, size) + if not headers["content-length"] then + return send_body_bychunks(sock, body_cb) + else + return send_body_bylength(sock, body_cb) + end end end ----------------------------------------------------------------------------- -- Determines if we should read a message body from the server response -- Input --- request: a table with the original request information --- response: a table with the server response information +-- reqt: a table with the original request information +-- respt: a table with the server response information -- Returns -- 1 if a message body should be processed, nil otherwise ----------------------------------------------------------------------------- -function Private.has_body(request, response) - if request.method == "HEAD" then return nil end - if response.code == 204 or response.code == 304 then return nil end - if response.code >= 100 and response.code < 200 then return nil end +local function should_receive_body(reqt, respt) + if reqt.method == "HEAD" then return nil end + if respt.code == 204 or respt.code == 304 then return nil end + if respt.code >= 100 and respt.code < 200 then return nil end return 1 end @@ -381,11 +376,11 @@ end -- Returns -- lower: a table with the same headers, but with lowercase field names ----------------------------------------------------------------------------- -function Private.fill_headers(headers, parsed) +local function fill_headers(headers, parsed) local lower = {} headers = headers or {} -- set default headers - lower["user-agent"] = Public.USERAGENT + lower["user-agent"] = USERAGENT -- override with user values for i,v in headers do lower[string.lower(i)] = v @@ -399,15 +394,15 @@ end ----------------------------------------------------------------------------- -- Decides wether we should follow retry with authorization formation -- Input --- request: a table with the original request information +-- reqt: a table with the original request information -- parsed: parsed request URL --- response: a table with the server response information +-- respt: a table with the server response information -- Returns -- 1 if we should retry, nil otherwise ----------------------------------------------------------------------------- -function Private.should_authorize(request, parsed, response) +local function should_authorize(reqt, parsed, respt) -- if there has been an authorization attempt, it must have failed - if request.headers["authorization"] then return nil end + if reqt.headers["authorization"] then return nil end -- if we don't have authorization information, we can't retry if parsed.user and parsed.password then return 1 else return nil end @@ -416,66 +411,66 @@ end ----------------------------------------------------------------------------- -- Returns the result of retrying a request with authorization information -- Input --- request: a table with the original request information +-- reqt: a table with the original request information -- parsed: parsed request URL --- response: a table with the server response information +-- respt: a table with the server response information -- Returns --- response: result of target redirection +-- respt: result of target authorization ----------------------------------------------------------------------------- -function Private.authorize(request, parsed, response) - request.headers["authorization"] = "Basic " .. - socket.code.base64(parsed.user .. ":" .. parsed.password) - local authorize = { - redirects = request.redirects, - method = request.method, - url = request.url, - body_cb = request.body_cb, - headers = request.headers +local function authorize(reqt, parsed, respt) + reqt.headers["authorization"] = "Basic " .. + socket.code.base64.encode(parsed.user .. ":" .. parsed.password) + local autht = { + nredirects = reqt.nredirects, + method = reqt.method, + url = reqt.url, + body_cb = reqt.body_cb, + headers = reqt.headers } - return Public.request_cb(authorize, response) + return request_cb(autht, respt) end ----------------------------------------------------------------------------- -- Decides wether we should follow a server redirect message -- Input --- request: a table with the original request information --- response: a table with the server response information +-- reqt: a table with the original request information +-- respt: a table with the server response information -- Returns -- 1 if we should redirect, nil otherwise ----------------------------------------------------------------------------- -function Private.should_redirect(request, response) - local follow = not request.stay - follow = follow and (response.code == 301 or response.code == 302) - follow = follow and (request.method == "GET" or request.method == "HEAD") - follow = follow and not (request.redirects and request.redirects >= 5) +local function should_redirect(reqt, respt) + local follow = not reqt.stay + follow = follow and (respt.code == 301 or respt.code == 302) + follow = follow and (reqt.method == "GET" or reqt.method == "HEAD") + follow = follow and not (reqt.nredirects and reqt.nredirects >= 5) return follow end ----------------------------------------------------------------------------- -- Returns the result of a request following a server redirect message. -- Input --- request: a table with the original request information --- response: a table with the following fields: +-- reqt: a table with the original request information +-- respt: a table with the following fields: -- body_cb: response method body receive-callback -- Returns --- response: result of target redirection +-- respt: result of target redirection ----------------------------------------------------------------------------- -function Private.redirect(request, response) - local redirects = request.redirects or 0 - redirects = redirects + 1 - local redirect = { - redirects = redirects, - method = request.method, +local function redirect(reqt, respt) + local nredirects = reqt.nredirects or 0 + nredirects = nredirects + 1 + local redirt = { + nredirects = nredirects, + method = reqt.method, -- the RFC says the redirect URL has to be absolute, but some -- servers do not respect that - url = socket.url.absolute(request.url, response.headers["location"]), - body_cb = request.body_cb, - headers = request.headers + url = socket.url.absolute(reqt.url, respt.headers["location"]), + body_cb = reqt.body_cb, + headers = reqt.headers } - local response = Public.request_cb(redirect, response) + respt = request_cb(redirt, respt) -- we pass the location header as a clue we tried to redirect - if response.headers then response.headers.location = redirect.url end - return response + if respt.headers then respt.headers.location = redirt.url end + return respt end ----------------------------------------------------------------------------- @@ -485,7 +480,7 @@ end -- Returns -- uri: request URI for parsed URL ----------------------------------------------------------------------------- -function Private.request_uri(parsed) +local function request_uri(parsed) local uri = "" if parsed.path then uri = uri .. parsed.path end if parsed.params then uri = uri .. ";" .. parsed.params end @@ -502,105 +497,110 @@ end -- user: account user name -- password: account password) -- Returns --- request: request table +-- reqt: request table ----------------------------------------------------------------------------- -function Private.build_request(data) - local request = {} +local function build_request(data) + local reqt = {} if type(data) == "table" then for i, v in data - do request[i] = v + do reqt[i] = v end - else request.url = data end - return request + else reqt.url = data end + return reqt end ----------------------------------------------------------------------------- -- Sends a HTTP request and retrieves the server reply using callbacks to -- send the request body and receive the response body -- Input --- request: a table with the following fields +-- reqt: a table with the following fields -- method: "GET", "PUT", "POST" etc (defaults to "GET") -- url: target uniform resource locator -- user, password: authentication information -- headers: request headers to send, or nil if none -- body_cb: request message body send-callback, or nil if none -- stay: should we refrain from following a server redirect message? --- response: a table with the following fields: +-- respt: a table with the following fields: -- body_cb: response method body receive-callback -- Returns --- response: a table with the following fields: +-- respt: a table with the following fields: -- headers: response header fields received, or nil if failed -- status: server response status line, or nil if failed -- code: server status code, or nil if failed -- error: error message, or nil if successfull ----------------------------------------------------------------------------- -function Public.request_cb(request, response) - local parsed = socket.url.parse(request.url, { +function request_cb(reqt, respt) + local parsed = socket.url.parse(reqt.url, { host = "", - port = Public.PORT, + port = PORT, path ="/", scheme = "http" }) if parsed.scheme ~= "http" then - response.error = string.format("unknown scheme '%s'", parsed.scheme) - return response + respt.error = string.format("unknown scheme '%s'", parsed.scheme) + return respt end -- explicit authentication info overrides that given by the URL - parsed.user = request.user or parsed.user - parsed.password = request.password or parsed.password + parsed.user = reqt.user or parsed.user + parsed.password = reqt.password or parsed.password -- default method - request.method = request.method or "GET" + reqt.method = reqt.method or "GET" -- fill default headers - request.headers = Private.fill_headers(request.headers, parsed) + reqt.headers = fill_headers(reqt.headers, parsed) -- try to connect to server local sock - sock, response.error = socket.connect(parsed.host, parsed.port) - if not sock then return response end + sock, respt.error = socket.connect(parsed.host, parsed.port) + if not sock then return respt end -- set connection timeout so that we do not hang forever - sock:settimeout(Public.TIMEOUT) + sock:settimeout(TIMEOUT) -- send request message - response.error = Private.send_request(sock, request.method, - Private.request_uri(parsed), request.headers, request.body_cb) - if response.error then return response end + respt.error = send_request(sock, reqt.method, + request_uri(parsed), reqt.headers, reqt.body_cb) + if respt.error then + sock:close() + return respt + end -- get server response message - response.code, response.status, response.error = - Private.receive_status(sock) - if response.error then return response end - -- deal with 1xx status - if response.code == 100 then - response.headers, response.error = Private.receive_headers(sock, {}) - if response.error then return response end - response.code, response.status, response.error = - Private.receive_status(sock) - if response.error then return response end + respt.code, respt.status, respt.error = receive_status(sock) + if respt.error then return respt end + -- deal with continue 100 + -- servers should not send them, but they might + if respt.code == 100 then + respt.headers, respt.error = receive_headers(sock, {}) + if respt.error then return respt end + respt.code, respt.status, respt.error = receive_status(sock) + if respt.error then return respt end end -- receive all headers - response.headers, response.error = Private.receive_headers(sock, {}) - if response.error then return response end + respt.headers, respt.error = receive_headers(sock, {}) + if respt.error then return respt end -- decide what to do based on request and response parameters - if Private.should_redirect(request, response) then - Private.drop_body(sock, response.headers) + if should_redirect(reqt, respt) then + -- drop the body + receive_body(sock, respt.headers, function (c, e) return 1 end) + -- we are done with this connection sock:close() - return Private.redirect(request, response) - elseif Private.should_authorize(request, parsed, response) then - Private.drop_body(sock, response.headers) + return redirect(reqt, respt) + elseif should_authorize(reqt, parsed, respt) then + -- drop the body + receive_body(sock, respt.headers, function (c, e) return 1 end) + -- we are done with this connection sock:close() - return Private.authorize(request, parsed, response) - elseif Private.has_body(request, response) then - response.error = Private.receive_body(sock, response.headers, - response.body_cb) - if response.error then return response end + return authorize(reqt, parsed, respt) + elseif should_receive_body(reqt, respt) then + respt.error = receive_body(sock, respt.headers, respt.body_cb) + if respt.error then return respt end sock:close() - return response + return respt end sock:close() - return response + return respt end ----------------------------------------------------------------------------- -- Sends a HTTP request and retrieves the server reply -- Input --- request: a table with the following fields +-- reqt: a table with the following fields -- method: "GET", "PUT", "POST" etc (defaults to "GET") -- url: request URL, i.e. the document to be retrieved -- user, password: authentication information @@ -608,22 +608,22 @@ end -- body: request message body as a string, or nil if none -- stay: should we refrain from following a server redirect message? -- Returns --- response: a table with the following fields: +-- respt: a table with the following fields: -- body: response message body, or nil if failed -- headers: response header fields, or nil if failed -- status: server response status line, or nil if failed -- code: server response status code, or nil if failed -- error: error message if any ----------------------------------------------------------------------------- -function Public.request(request) - local response = {} - request.body_cb = socket.callback.send_string(request.body) +function request(reqt) + local respt = {} + reqt.body_cb = socket.callback.send_string(reqt.body) local concat = socket.concat.create() - response.body_cb = socket.callback.receive_concat(concat) - response = Public.request_cb(request, response) - response.body = concat:getresult() - response.body_cb = nil - return response + respt.body_cb = socket.callback.receive_concat(concat) + respt = request_cb(reqt, respt) + respt.body = concat:getresult() + respt.body_cb = nil + return respt end ----------------------------------------------------------------------------- @@ -639,12 +639,11 @@ end -- code: server response status code, or nil if failed -- error: error message if any ----------------------------------------------------------------------------- -function Public.get(url_or_request) - local request = Private.build_request(url_or_request) - request.method = "GET" - local response = Public.request(request) - return response.body, response.headers, - response.code, response.error +function get(url_or_request) + local reqt = build_request(url_or_request) + reqt.method = "GET" + local respt = request(reqt) + return respt.body, respt.headers, respt.code, respt.error end ----------------------------------------------------------------------------- @@ -662,11 +661,14 @@ end -- code: server response status code, or nil if failed -- error: error message, or nil if successfull ----------------------------------------------------------------------------- -function Public.post(url_or_request, body) - local request = Private.build_request(url_or_request) - request.method = "POST" - request.body = request.body or body - local response = Public.request(request) - return response.body, response.headers, - response.code, response.error +function post(url_or_request, body) + local reqt = build_request(url_or_request) + reqt.method = "POST" + reqt.body = reqt.body or body + reqt.headers = reqt.headers or + { ["content-length"] = string.len(reqt.body) } + local respt = request(reqt) + return respt.body, respt.headers, respt.code, respt.error end + +return http diff --git a/src/socket.h b/src/socket.h index c7db5f2..cea9e0d 100644 --- a/src/socket.h +++ b/src/socket.h @@ -37,6 +37,7 @@ int sock_accept(p_sock ps, p_sock pa, SA *addr, socklen_t *addr_len, const char *sock_connect(p_sock ps, SA *addr, socklen_t addr_len); const char *sock_bind(p_sock ps, SA *addr, socklen_t addr_len); void sock_listen(p_sock ps, int backlog); +void sock_shutdown(p_sock ps, int how); int sock_send(p_sock ps, const char *data, size_t count, size_t *sent, int timeout); int sock_recv(p_sock ps, char *data, size_t count, diff --git a/src/tcp.c b/src/tcp.c index d68db08..afa0477 100644 --- a/src/tcp.c +++ b/src/tcp.c @@ -25,6 +25,7 @@ static int meth_bind(lua_State *L); static int meth_send(lua_State *L); static int meth_getsockname(lua_State *L); static int meth_getpeername(lua_State *L); +static int meth_shutdown(lua_State *L); static int meth_receive(lua_State *L); static int meth_accept(lua_State *L); static int meth_close(lua_State *L); @@ -49,6 +50,7 @@ static luaL_reg tcp[] = { {"getsockname", meth_getsockname}, {"settimeout", meth_settimeout}, {"close", meth_close}, + {"shutdown", meth_shutdown}, {"setoption", meth_setoption}, {"__gc", meth_close}, {"fd", meth_fd}, @@ -201,12 +203,12 @@ static int meth_accept(lua_State *L) int err = IO_ERROR; p_tcp server = (p_tcp) aux_checkclass(L, "tcp{server}", 1); p_tm tm = &server->tm; - p_tcp client = lua_newuserdata(L, sizeof(t_tcp)); - aux_setclass(L, "tcp{client}", -1); + p_tcp client; + t_sock sock; tm_markstart(tm); /* loop until connection accepted or timeout happens */ while (err != IO_DONE) { - err = sock_accept(&server->sock, &client->sock, + err = sock_accept(&server->sock, &sock, (SA *) &addr, &addr_len, tm_getfailure(tm)); if (err == IO_CLOSED || (err == IO_TIMEOUT && !tm_getfailure(tm))) { lua_pushnil(L); @@ -214,6 +216,9 @@ static int meth_accept(lua_State *L) return 2; } } + client = lua_newuserdata(L, sizeof(t_tcp)); + aux_setclass(L, "tcp{client}", -1); + client->sock = sock; /* initialize remaining structure fields */ io_init(&client->io, (p_send) sock_send, (p_recv) sock_recv, &client->sock); tm_init(&client->tm, -1, -1); @@ -272,6 +277,33 @@ static int meth_close(lua_State *L) return 0; } +/*-------------------------------------------------------------------------*\ +* Shuts the connection down +\*-------------------------------------------------------------------------*/ +static int meth_shutdown(lua_State *L) +{ + p_tcp tcp = (p_tcp) aux_checkgroup(L, "tcp{any}", 1); + const char *how = luaL_optstring(L, 2, "both"); + switch (how[0]) { + case 'b': + if (strcmp(how, "both")) goto error; + sock_shutdown(&tcp->sock, 2); + break; + case 's': + if (strcmp(how, "send")) goto error; + sock_shutdown(&tcp->sock, 1); + break; + case 'r': + if (strcmp(how, "receive")) goto error; + sock_shutdown(&tcp->sock, 0); + break; + } + return 0; +error: + luaL_argerror(L, 2, "invalid shutdown method"); + return 0; +} + /*-------------------------------------------------------------------------*\ * Just call inet methods \*-------------------------------------------------------------------------*/ diff --git a/src/usocket.c b/src/usocket.c index b120d7b..f2d9f01 100644 --- a/src/usocket.c +++ b/src/usocket.c @@ -72,6 +72,14 @@ void sock_listen(p_sock ps, int backlog) listen(*ps, backlog); } +/*-------------------------------------------------------------------------*\ +* +\*-------------------------------------------------------------------------*/ +void sock_shutdown(p_sock ps, int how) +{ + shutdown(*ps, how); +} + /*-------------------------------------------------------------------------*\ * Accept with timeout \*-------------------------------------------------------------------------*/ @@ -100,39 +108,47 @@ int sock_accept(p_sock ps, p_sock pa, SA *addr, socklen_t *addr_len, /*-------------------------------------------------------------------------*\ * Send with timeout +* Here we exchanged the order of the calls to write and select +* The idea is that the outer loop (whoever is calling sock_send) +* will call the function again if we didn't time out, so we can +* call write and then select only if it fails. +* Should speed things up! +* We are also treating EINTR and EPIPE errors. \*-------------------------------------------------------------------------*/ int sock_send(p_sock ps, const char *data, size_t count, size_t *sent, int timeout) { t_sock sock = *ps; - struct timeval tv; - fd_set fds; - ssize_t put = 0; - int err; + ssize_t put; int ret; + /* avoid making system calls on closed sockets */ if (sock == SOCK_INVALID) return IO_CLOSED; - tv.tv_sec = timeout / 1000; - tv.tv_usec = (timeout % 1000) * 1000; - FD_ZERO(&fds); - FD_SET(sock, &fds); - ret = select(sock+1, NULL, &fds, NULL, timeout >= 0 ? &tv : NULL); - if (ret > 0) { - put = write(sock, data, count); - if (put <= 0) { - err = IO_CLOSED; -#ifdef __CYGWIN__ - /* this is for CYGWIN, which is like Unix but has Win32 bugs */ - if (errno == EWOULDBLOCK) err = IO_DONE; -#endif - *sent = 0; - } else { - *sent = put; - err = IO_DONE; - } - return err; - } else { + /* make sure we repeat in case the call was interrupted */ + do put = write(sock, data, count); + while (put <= 0 && errno == EINTR); + /* deal with failure */ + if (put <= 0) { + /* in any case, nothing has been sent */ *sent = 0; - return IO_TIMEOUT; + /* run select to avoid busy wait */ + if (errno != EPIPE) { + struct timeval tv; + fd_set fds; + tv.tv_sec = timeout / 1000; + tv.tv_usec = (timeout % 1000) * 1000; + FD_ZERO(&fds); + FD_SET(sock, &fds); + ret = select(sock+1, NULL, &fds, NULL, timeout >= 0 ? &tv : NULL); + /* tell the caller to call us again because there is more data */ + if (ret > 0) return IO_DONE; + /* tell the caller there was no data before timeout */ + else return IO_TIMEOUT; + /* here we know the connection has been closed */ + } else return IO_CLOSED; + /* here we sent successfully sent something */ + } else { + *sent = put; + return IO_DONE; } } @@ -176,32 +192,36 @@ int sock_sendto(p_sock ps, const char *data, size_t count, size_t *sent, /*-------------------------------------------------------------------------*\ * Receive with timeout +* Here we exchanged the order of the calls to write and select +* The idea is that the outer loop (whoever is calling sock_send) +* will call the function again if we didn't time out, so we can +* call write and then select only if it fails. +* Should speed things up! +* We are also treating EINTR errors. \*-------------------------------------------------------------------------*/ int sock_recv(p_sock ps, char *data, size_t count, size_t *got, int timeout) { t_sock sock = *ps; - struct timeval tv; - fd_set fds; - int ret; - ssize_t taken = 0; + ssize_t taken; if (sock == SOCK_INVALID) return IO_CLOSED; - tv.tv_sec = timeout / 1000; - tv.tv_usec = (timeout % 1000) * 1000; - FD_ZERO(&fds); - FD_SET(sock, &fds); - ret = select(sock+1, &fds, NULL, NULL, timeout >= 0 ? &tv : NULL); - if (ret > 0) { - taken = read(sock, data, count); - if (taken <= 0) { - *got = 0; - return IO_CLOSED; - } else { - *got = taken; - return IO_DONE; - } - } else { + do taken = read(sock, data, count); + while (taken <= 0 && errno == EINTR); + if (taken <= 0) { + struct timeval tv; + fd_set fds; + int ret; *got = 0; - return IO_TIMEOUT; + if (taken == 0) return IO_CLOSED; + tv.tv_sec = timeout / 1000; + tv.tv_usec = (timeout % 1000) * 1000; + FD_ZERO(&fds); + FD_SET(sock, &fds); + ret = select(sock+1, &fds, NULL, NULL, timeout >= 0 ? &tv : NULL); + if (ret > 0) return IO_DONE; + else return IO_TIMEOUT; + } else { + *got = taken; + return IO_DONE; } } diff --git a/src/wsocket.c b/src/wsocket.c index 1ba28b6..59d88df 100644 --- a/src/wsocket.c +++ b/src/wsocket.c @@ -35,6 +35,14 @@ void sock_destroy(p_sock ps) } } +/*-------------------------------------------------------------------------*\ +* +\*-------------------------------------------------------------------------*/ +void sock_shutdown(p_sock ps, int how) +{ + shutdown(*ps, how); +} + /*-------------------------------------------------------------------------*\ * Creates and sets up a socket \*-------------------------------------------------------------------------*/ diff --git a/test/auth/.htaccess b/test/auth/.htaccess index b9f100e..31e1123 100644 --- a/test/auth/.htaccess +++ b/test/auth/.htaccess @@ -1,4 +1,4 @@ AuthName "Test Realm" AuthType Basic -AuthUserFile /home/diego/tec/luasocket/test/auth/.htpasswd +AuthUserFile /Users/diego/tec/luasocket/test/auth/.htpasswd require valid-user diff --git a/test/httptest.lua b/test/httptest.lua index 030974c..9d9fa25 100644 --- a/test/httptest.lua +++ b/test/httptest.lua @@ -1,8 +1,8 @@ -- needs Alias from /home/c/diego/tec/luasocket/test to --- /luasocket-test +-- "/luasocket-test" and "/luasocket-test/" -- needs ScriptAlias from /home/c/diego/tec/luasocket/test/cgi --- to /luasocket-test-cgi --- needs AllowOverride AuthConfig on /home/c/diego/tec/luasocket/test/auth +-- to "/luasocket-test-cgi" and "/luasocket-test-cgi/" +-- needs "AllowOverride AuthConfig" on /home/c/diego/tec/luasocket/test/auth local similar = function(s1, s2) return string.lower(string.gsub(s1 or "", "%s", "")) == string.lower(string.gsub(s2 or "", "%s", "")) @@ -31,12 +31,18 @@ local check_request = function(request, expect, ignore) local response = socket.http.request(request) for i,v in response do if not ignore[i] then - if v ~= expect[i] then fail(i .. " differs!") end + if v ~= expect[i] then + if string.len(v) < 80 then print(v) end + fail(i .. " differs!") + end end end for i,v in expect do if not ignore[i] then - if v ~= response[i] then fail(i .. " differs!") end + if v ~= response[i] then + if string.len(v) < 80 then print(v) end + fail(i .. " differs!") + end end end print("ok") @@ -47,15 +53,18 @@ local host, request, response, ignore, expect, index, prefix, cgiprefix local t = socket.time() host = host or "localhost" -prefix = prefix or "/luasocket" -cgiprefix = cgiprefix or "/luasocket/cgi" +prefix = prefix or "/luasocket-test" +cgiprefix = cgiprefix or "/luasocket-test-cgi" index = readfile("test/index.html") io.write("testing request uri correctness: ") local forth = cgiprefix .. "/request-uri?" .. "this+is+the+query+string" -local back = socket.http.get("http://" .. host .. forth) +local back, h, c, e = socket.http.get("http://" .. host .. forth) if similar(back, forth) then print("ok") -else fail("failed!") end +else +print(h, c, e) +fail() +end io.write("testing query string correctness: ") forth = "this+is+the+query+string" @@ -77,6 +86,38 @@ ignore = { } check_request(request, expect, ignore) +socket.http.get("http://" .. host .. prefix .. "/lixo.html") + +io.write("testing post method: ") +-- wanted to test chunked post, but apache doesn't support it... +request = { + url = "http://" .. host .. cgiprefix .. "/cat", + method = "POST", + body = index, + -- remove content-length header to send chunked body + headers = { ["content-length"] = string.len(index) } +} +expect = { + body = index, + code = 200 +} +ignore = { + status = 1, + headers = 1 +} +check_request(request, expect, ignore) + +io.write("testing simple post function: ") +body = socket.http.post("http://" .. host .. cgiprefix .. "/cat", index) +check(body == index) + +io.write("testing simple post function with table args: ") +body = socket.http.post { + url = "http://" .. host .. cgiprefix .. "/cat", + body = index +} +check(body == index) + io.write("testing http redirection: ") request = { url = "http://" .. host .. prefix @@ -175,7 +216,8 @@ io.write("testing manual basic auth: ") request = { url = "http://" .. host .. prefix .. "/auth/index.html", headers = { - authorization = "Basic " .. socket.code.base64("luasocket:password") + authorization = "Basic " .. + socket.code.base64.encode("luasocket:password") } } expect = { @@ -246,22 +288,6 @@ ignore = { } check_request(request, expect, ignore) -io.write("testing post method: ") -request = { - url = "http://" .. host .. cgiprefix .. "/cat", - method = "POST", - body = index -} -expect = { - body = index, - code = 200 -} -ignore = { - status = 1, - headers = 1 -} -check_request(request, expect, ignore) - io.write("testing wrong scheme: ") request = { url = "wrong://" .. host .. cgiprefix .. "/cat", @@ -287,17 +313,6 @@ body = socket.http.get { } check(body == index) -io.write("testing simple post function: ") -body = socket.http.post("http://" .. host .. cgiprefix .. "/cat", index) -check(body == index) - -io.write("testing simple post function with table args: ") -body = socket.http.post { - url = "http://" .. host .. cgiprefix .. "/cat", - body = index -} -check(body == index) - io.write("testing HEAD method: ") response = socket.http.request { method = "HEAD", diff --git a/test/testclnt.lua b/test/testclnt.lua index 3f217bd..2420711 100644 --- a/test/testclnt.lua +++ b/test/testclnt.lua @@ -84,19 +84,19 @@ function reconnect() remote [[ if data then data:close() data = nil end data = server:accept() - -- data:setoption("nodelay", true) + data:setoption("tcp-nodelay", true) ]] data, err = socket.connect(host, port) if not data then fail(err) else pass("connected!") end - -- data:setoption("nodelay", true) + data:setoption("tcp-nodelay", true) end pass("attempting control connection...") control, err = socket.connect(host, port) if err then fail(err) else pass("connected!") end --- control:setoption("nodelay", true) +control:setoption("tcp-nodelay", true) ------------------------------------------------------------------------ test("method registration")