Changed receive function. Now uniform with all other functions. Returns nil

on error, return partial result in the end.

http.lua rewritten.
This commit is contained in:
Diego Nehab 2004-03-21 07:50:15 +00:00
parent 2a14ac4fe4
commit 4919a83d22
9 changed files with 316 additions and 651 deletions

4
TODO
View File

@ -19,6 +19,10 @@
* Separar as classes em arquivos
* Retorno de sendto em datagram sockets pode ser refused
change mime.eol to output marker on detection of first candidate, instead
of on the second. that way it works in one pass for strings that end with
one candidate.
colocar um userdata com gc metamethod pra chamar sock_close (WSAClose);
sources ans sinks are always simple in http and ftp and smtp
unify backbone of smtp and ftp

View File

@ -13,9 +13,9 @@
/*=========================================================================*\
* Internal function prototypes
\*=========================================================================*/
static int recvraw(lua_State *L, p_buf buf, size_t wanted);
static int recvline(lua_State *L, p_buf buf);
static int recvall(lua_State *L, p_buf buf);
static int recvraw(p_buf buf, size_t wanted, luaL_Buffer *b);
static int recvline(p_buf buf, luaL_Buffer *b);
static int recvall(p_buf buf, luaL_Buffer *b);
static int buf_get(p_buf buf, const char **data, size_t *count);
static void buf_skip(p_buf buf, size_t count);
static int sendraw(p_buf buf, const char *data, size_t count, size_t *sent);
@ -73,42 +73,34 @@ int buf_meth_send(lua_State *L, p_buf buf)
\*-------------------------------------------------------------------------*/
int buf_meth_receive(lua_State *L, p_buf buf)
{
int top = lua_gettop(L);
int arg, err = IO_DONE;
int err = IO_DONE, top = lua_gettop(L);
p_tm tm = buf->tm;
luaL_Buffer b;
luaL_buffinit(L, &b);
tm_markstart(tm);
/* push default pattern if need be */
if (top < 2) {
lua_pushstring(L, "*l");
top++;
}
/* make sure we have enough stack space for all returns */
luaL_checkstack(L, top+LUA_MINSTACK, "too many arguments");
/* receive all patterns */
for (arg = 2; arg <= top && err == IO_DONE; arg++) {
if (!lua_isnumber(L, arg)) {
static const char *patternnames[] = {"*l", "*a", NULL};
const char *pattern = lua_isnil(L, arg) ?
"*l" : luaL_checkstring(L, arg);
/* get next pattern */
switch (luaL_findstring(pattern, patternnames)) {
case 0: /* line pattern */
err = recvline(L, buf); break;
case 1: /* until closed pattern */
err = recvall(L, buf);
if (err == IO_CLOSED) err = IO_DONE;
break;
default: /* else it is an error */
luaL_argcheck(L, 0, arg, "invalid receive pattern");
break;
}
if (!lua_isnumber(L, 2)) {
static const char *patternnames[] = {"*l", "*a", NULL};
const char *pattern = luaL_optstring(L, 2, "*l");
/* get next pattern */
int p = luaL_findstring(pattern, patternnames);
if (p == 0) err = recvline(buf, &b);
else if (p == 1) err = recvall(buf, &b);
else luaL_argcheck(L, 0, 2, "invalid receive pattern");
/* get a fixed number of bytes */
} else err = recvraw(L, buf, (size_t) lua_tonumber(L, arg));
} else err = recvraw(buf, (size_t) lua_tonumber(L, 2), &b);
/* check if there was an error */
if (err != IO_DONE) {
luaL_pushresult(&b);
io_pusherror(L, err);
lua_pushvalue(L, -2);
lua_pushnil(L);
lua_replace(L, -4);
} else {
luaL_pushresult(&b);
lua_pushnil(L);
lua_pushnil(L);
}
/* push nil for each pattern after an error */
for ( ; arg <= top; arg++) lua_pushnil(L);
/* last return is an error code */
io_pusherror(L, err);
#ifdef LUASOCKET_DEBUG
/* push time elapsed during operation as the last return value */
lua_pushnumber(L, (tm_gettime() - tm_getstart(tm))/1000.0);
@ -150,21 +142,18 @@ int sendraw(p_buf buf, const char *data, size_t count, size_t *sent)
* Reads a fixed number of bytes (buffered)
\*-------------------------------------------------------------------------*/
static
int recvraw(lua_State *L, p_buf buf, size_t wanted)
int recvraw(p_buf buf, size_t wanted, luaL_Buffer *b)
{
int err = IO_DONE;
size_t total = 0;
luaL_Buffer b;
luaL_buffinit(L, &b);
while (total < wanted && (err == IO_DONE || err == IO_RETRY)) {
size_t count; const char *data;
err = buf_get(buf, &data, &count);
count = MIN(count, wanted - total);
luaL_addlstring(&b, data, count);
luaL_addlstring(b, data, count);
buf_skip(buf, count);
total += count;
}
luaL_pushresult(&b);
return err;
}
@ -172,19 +161,17 @@ int recvraw(lua_State *L, p_buf buf, size_t wanted)
* Reads everything until the connection is closed (buffered)
\*-------------------------------------------------------------------------*/
static
int recvall(lua_State *L, p_buf buf)
int recvall(p_buf buf, luaL_Buffer *b)
{
int err = IO_DONE;
luaL_Buffer b;
luaL_buffinit(L, &b);
while (err == IO_DONE || err == IO_RETRY) {
const char *data; size_t count;
err = buf_get(buf, &data, &count);
luaL_addlstring(&b, data, count);
luaL_addlstring(b, data, count);
buf_skip(buf, count);
}
luaL_pushresult(&b);
return err;
if (err == IO_CLOSED) return IO_DONE;
else return err;
}
/*-------------------------------------------------------------------------*\
@ -192,18 +179,16 @@ int recvall(lua_State *L, p_buf buf)
* are not returned by the function and are discarded from the buffer
\*-------------------------------------------------------------------------*/
static
int recvline(lua_State *L, p_buf buf)
int recvline(p_buf buf, luaL_Buffer *b)
{
int err = IO_DONE;
luaL_Buffer b;
luaL_buffinit(L, &b);
while (err == IO_DONE || err == IO_RETRY) {
size_t count, pos; const char *data;
err = buf_get(buf, &data, &count);
pos = 0;
while (pos < count && data[pos] != '\n') {
/* we ignore all \r's */
if (data[pos] != '\r') luaL_putchar(&b, data[pos]);
if (data[pos] != '\r') luaL_putchar(b, data[pos]);
pos++;
}
if (pos < count) { /* found '\n' */
@ -212,7 +197,6 @@ int recvline(lua_State *L, p_buf buf)
} else /* reached the end of the buffer */
buf_skip(buf, pos);
}
luaL_pushresult(&b);
return err;
}

View File

@ -39,321 +39,146 @@ local function third(a, b, c)
return c
end
-----------------------------------------------------------------------------
-- Tries to get a pattern from the server and closes socket on error
-- sock: socket connected to the server
-- pattern: pattern to receive
-- Returns
-- received pattern on success
-- nil followed by error message on error
-----------------------------------------------------------------------------
local function try_receiving(sock, pattern)
local data, err = sock:receive(pattern)
if not data then sock:close() end
--print(data)
return data, err
local function shift(a, b, c, d)
return c, d
end
-----------------------------------------------------------------------------
-- Tries to send data to the server and closes socket on error
-- sock: socket connected to the server
-- ...: data to send
-- Returns
-- err: error message if any, nil if successfull
-----------------------------------------------------------------------------
local function try_sending(sock, ...)
local sent, err = sock:send(unpack(arg))
if not sent then sock:close() end
--io.write(unpack(arg))
return err
end
-- resquest_p forward declaration
local request_p
-----------------------------------------------------------------------------
-- Receive server reply messages, parsing for status code
-- Input
-- sock: socket connected to the server
-- Returns
-- code: server status code or nil if error
-- line: full HTTP status line
-- err: error message if any
-----------------------------------------------------------------------------
local function receive_status(sock)
local line, err = try_receiving(sock)
if not err then
local code = third(string.find(line, "HTTP/%d*%.%d* (%d%d%d)"))
return tonumber(code), line
else return nil, nil, err end
end
-----------------------------------------------------------------------------
-- Receive and parse response header fields
-- Input
-- sock: socket connected to the server
-- headers: a table that might already contain headers
-- Returns
-- headers: a table with all headers fields in the form
-- {name_1 = "value_1", name_2 = "value_2" ... name_n = "value_n"}
-- all name_i are lowercase
-- nil and error message in case of error
-----------------------------------------------------------------------------
local function receive_headers(sock, headers)
local line, err
local name, value, _
headers = headers or {}
local line, name, value
-- get first line
line, err = try_receiving(sock)
if err then return nil, err end
line = socket.try(sock:receive())
-- headers go until a blank line is found
while line ~= "" do
-- get field-name and value
_,_, name, value = string.find(line, "^(.-):%s*(.*)")
if not name or not value then
sock:close()
return nil, "malformed reponse headers"
end
name, value = shift(string.find(line, "^(.-):%s*(.*)"))
assert(name and value, "malformed reponse headers")
name = string.lower(name)
-- get next line (value might be folded)
line, err = try_receiving(sock)
if err then return nil, err end
line = socket.try(sock:receive())
-- unfold any folded values
while not err and string.find(line, "^%s") do
while string.find(line, "^%s") do
value = value .. line
line, err = try_receiving(sock)
if err then return nil, err end
line = socket.try(sock:receive())
end
-- save pair in table
if headers[name] then headers[name] = headers[name] .. ", " .. value
else headers[name] = value end
end
return headers
end
-----------------------------------------------------------------------------
-- Aborts a sink with an error message
-- Input
-- cb: callback function
-- err: error message to pass to callback
-- Returns
-- callback return or if nil err
-----------------------------------------------------------------------------
local function abort(cb, err)
local go, cb_err = cb(nil, err)
return cb_err or err
error(cb_err or err)
end
-----------------------------------------------------------------------------
-- Receives a chunked message body
-- Input
-- sock: socket connected to the server
-- headers: header set in which to include trailer headers
-- sink: response message body sink
-- Returns
-- nil if successfull or an error message in case of error
-----------------------------------------------------------------------------
local function receive_body_bychunks(sock, headers, sink)
local chunk, size, line, err, go
local function hand(cb, chunk)
local go, cb_err = cb(chunk)
assert(go, cb_err or "aborted by callback")
end
local function receive_body_bychunks(sock, sink)
while 1 do
-- get chunk size, skip extention
line, err = try_receiving(sock)
if err then return abort(sink, err) end
size = tonumber(string.gsub(line, ";.*", ""), 16)
if not size then return abort(sink, "invalid chunk size") end
local line, err = sock:receive()
if err then abort(sink, err) end
local size = tonumber(string.gsub(line, ";.*", ""), 16)
if not size then abort(sink, "invalid chunk size") end
-- was it the last chunk?
if size <= 0 then break end
-- get chunk
chunk, err = try_receiving(sock, size)
if err then return abort(sink, err) end
local chunk, err = sock:receive(size)
if err then abort(sink, err) end
-- pass chunk to callback
go, err = sink(chunk)
-- see if callback aborted
if not go then return err or "aborted by callback" end
hand(sink, chunk)
-- skip CRLF on end of chunk
err = second(try_receiving(sock))
if err then return abort(sink, err) end
err = second(sock:receive())
if err then abort(sink, err) end
end
-- servers shouldn't send trailer headers, but who trusts them?
err = second(receive_headers(sock, headers))
if err then return abort(sink, err) end
-- let callback know we are done
return second(sink(nil))
hand(sink, nil)
-- servers shouldn't send trailer headers, but who trusts them?
receive_headers(sock, {})
end
-----------------------------------------------------------------------------
-- Receives a message body by content-length
-- Input
-- sock: socket connected to the server
-- length: message body length
-- sink: response message body sink
-- Returns
-- nil if successfull or an error message in case of error
-----------------------------------------------------------------------------
local function receive_body_bylength(sock, length, sink)
while length > 0 do
local size = math.min(BLOCKSIZE, length)
local chunk, err = sock:receive(size)
local go, cb_err = sink(chunk)
if err then abort(sink, err) end
length = length - string.len(chunk)
-- see if callback aborted
if not go then return cb_err or "aborted by callback" end
-- see if there was an error
if err and length > 0 then return abort(sink, err) end
hand(sink, chunk)
end
return second(sink(nil))
-- let callback know we are done
hand(sink, nil)
end
-----------------------------------------------------------------------------
-- Receives a message body until the conection is closed
-- Input
-- sock: socket connected to the server
-- sink: response message body sink
-- Returns
-- nil if successfull or an error message in case of error
-----------------------------------------------------------------------------
local function receive_body_untilclosed(sock, sink)
while 1 do
local chunk, err = sock:receive(BLOCKSIZE)
local go, cb_err = sink(chunk)
-- see if callback aborted
if not go then return cb_err or "aborted by callback" end
while true do
local chunk, err, partial = sock:receive(BLOCKSIZE)
-- see if we are done
if err == "closed" then return chunk and second(sink(nil)) end
if err == "closed" then
hand(sink, partial)
break
end
hand(sink, chunk)
-- see if there was an error
if err then return abort(sink, err) end
if err then abort(sink, err) end
end
-- let callback know we are done
hand(sink, nil)
end
-----------------------------------------------------------------------------
-- Receives the HTTP response body
-- Input
-- sock: socket connected to the server
-- headers: response header fields
-- sink: response message body sink
-- Returns
-- nil if successfull or an error message in case of error
-----------------------------------------------------------------------------
local function receive_body(sock, headers, sink)
-- make sure sink is not fancy
sink = ltn12.sink.simplify(sink)
local function receive_body(reqt, respt)
local sink = reqt.sink or ltn12.sink.null()
local headers = respt.headers
local sock = respt.tmp.sock
local te = headers["transfer-encoding"]
if te and te ~= "identity" then
-- get by chunked transfer-coding of message body
return receive_body_bychunks(sock, headers, sink)
receive_body_bychunks(sock, sink)
elseif tonumber(headers["content-length"]) then
-- get by content-length
local length = tonumber(headers["content-length"])
return receive_body_bylength(sock, length, sink)
receive_body_bylength(sock, length, sink)
else
-- get it all until connection closes
return receive_body_untilclosed(sock, sink)
receive_body_untilclosed(sock, sink)
end
end
-----------------------------------------------------------------------------
-- Sends the HTTP request message body in chunks
-- Input
-- data: data connection
-- source: request message body source
-- Returns
-- nil if successfull, or an error message in case of error
-----------------------------------------------------------------------------
local function send_body_bychunks(data, source)
while 1 do
local chunk, cb_err = source()
-- check if callback aborted
if not chunk then return cb_err or "aborted by callback" end
-- if we are done, send last-chunk
if chunk == "" then return try_sending(data, "0\r\n\r\n") end
-- else send middle chunk
local err = try_sending(data,
string.format("%X\r\n", string.len(chunk)),
chunk,
"\r\n"
)
if err then return err end
while true do
local chunk, err = source()
assert(chunk or not err, err)
if not chunk then break end
socket.try(data:send(string.format("%X\r\n", string.len(chunk))))
socket.try(data:send(chunk, "\r\n"))
end
socket.try(data:send("0\r\n\r\n"))
end
-----------------------------------------------------------------------------
-- Sends the HTTP request message body
-- Input
-- data: data connection
-- source: request message body source
-- Returns
-- nil if successfull, or an error message in case of error
-----------------------------------------------------------------------------
local function send_body(data, source)
while 1 do
local chunk, cb_err = source()
-- check if callback is done
if not chunk then return cb_err end
-- send data
local err = try_sending(data, chunk)
if err then return err end
while true do
local chunk, err = source()
assert(chunk or not err, err)
if not chunk then break end
socket.try(data:send(chunk))
end
end
-----------------------------------------------------------------------------
-- Sends request headers
-- Input
-- sock: server socket
-- headers: table with headers to be sent
-- Returns
-- err: error message if any
-----------------------------------------------------------------------------
local function send_headers(sock, headers)
local err
headers = headers or {}
-- send request headers
for i, v in headers do
err = try_sending(sock, i .. ": " .. v .. "\r\n")
if err then return err end
for i, v in pairs(headers) do
socket.try(sock:send(i .. ": " .. v .. "\r\n"))
end
-- mark end of request headers
return try_sending(sock, "\r\n")
socket.try(sock:send("\r\n"))
end
-----------------------------------------------------------------------------
-- Sends a HTTP request message through socket
-- Input
-- sock: socket connected to the server
-- method: request method to be used
-- uri: request uri
-- headers: request headers to be sent
-- source: request message body source
-- Returns
-- err: nil in case of success, error message otherwise
-----------------------------------------------------------------------------
local function send_request(sock, method, uri, headers, source)
local chunk, size, done, err
-- send request line
err = try_sending(sock, method .. " " .. uri .. " HTTP/1.1\r\n")
if err then return err end
if source and not headers["content-length"] then
headers["transfer-encoding"] = "chunked"
end
-- send request headers
err = send_headers(sock, headers)
if err then return err end
-- send request message body, if any
if source then
-- make sure source is not fancy
source = ltn12.source.simplify(source)
if headers["content-length"] then
return send_body(sock, source)
else
return send_body_bychunks(sock, source)
end
end
end
-----------------------------------------------------------------------------
-- Determines if we should read a message body from the server response
-- Input
-- reqt: a table with the original request information
-- respt: a table with the server response information
-- Returns
-- 1 if a message body should be processed, nil otherwise
-----------------------------------------------------------------------------
local function should_receive_body(reqt, respt)
if reqt.method == "HEAD" then return nil end
if respt.code == 204 or respt.code == 304 then return nil end
@ -361,125 +186,17 @@ local function should_receive_body(reqt, respt)
return 1
end
-----------------------------------------------------------------------------
-- Converts field names to lowercase and adds a few needed headers
-- Input
-- headers: request header fields
-- parsed: parsed request URL
-- Returns
-- lower: a table with the same headers, but with lowercase field names
-----------------------------------------------------------------------------
local function fill_headers(headers, parsed)
local lower = {}
headers = headers or {}
-- set default headers
lower["user-agent"] = USERAGENT
-- override with user values
for i,v in headers do
lower[string.lower(i)] = v
end
lower["host"] = parsed.host
-- this cannot be overriden
lower["connection"] = "close"
return lower
local function receive_status(reqt, respt)
local sock = respt.tmp.sock
local status = socket.try(sock:receive())
local code = third(string.find(status, "HTTP/%d*%.%d* (%d%d%d)"))
-- store results
respt.code, respt.status = tonumber(code), status
end
-----------------------------------------------------------------------------
-- Decides wether we should follow retry with authorization formation
-- Input
-- reqt: a table with the original request information
-- parsed: parsed request URL
-- respt: a table with the server response information
-- Returns
-- 1 if we should retry, nil otherwise
-----------------------------------------------------------------------------
local function should_authorize(reqt, parsed, respt)
-- if there has been an authorization attempt, it must have failed
if reqt.headers["authorization"] then return nil end
-- if we don't have authorization information, we can't retry
if parsed.user and parsed.password then return 1
else return nil end
end
-----------------------------------------------------------------------------
-- Returns the result of retrying a request with authorization information
-- Input
-- reqt: a table with the original request information
-- parsed: parsed request URL
-- Returns
-- respt: result of target authorization
-----------------------------------------------------------------------------
local function authorize(reqt, parsed)
reqt.headers["authorization"] = "Basic " ..
(mime.b64(parsed.user .. ":" .. parsed.password))
local autht = {
nredirects = reqt.nredirects,
method = reqt.method,
url = reqt.url,
source = reqt.source,
sink = reqt.sink,
headers = reqt.headers,
timeout = reqt.timeout,
proxy = reqt.proxy,
}
return request_cb(autht)
end
-----------------------------------------------------------------------------
-- Decides wether we should follow a server redirect message
-- Input
-- reqt: a table with the original request information
-- respt: a table with the server response information
-- Returns
-- 1 if we should redirect, nil otherwise
-----------------------------------------------------------------------------
local function should_redirect(reqt, respt)
return (reqt.redirect ~= false) and
(respt.code == 301 or respt.code == 302) and
(reqt.method == "GET" or reqt.method == "HEAD") and
not (reqt.nredirects and reqt.nredirects >= 5)
end
-----------------------------------------------------------------------------
-- Returns the result of a request following a server redirect message.
-- Input
-- reqt: a table with the original request information
-- respt: response table of previous attempt
-- Returns
-- respt: result of target redirection
-----------------------------------------------------------------------------
local function redirect(reqt, respt)
local nredirects = reqt.nredirects or 0
nredirects = nredirects + 1
local redirt = {
nredirects = nredirects,
method = reqt.method,
-- the RFC says the redirect URL has to be absolute, but some
-- servers do not respect that
url = socket.url.absolute(reqt.url, respt.headers["location"]),
source = reqt.source,
sink = reqt.sink,
headers = reqt.headers,
timeout = reqt.timeout,
proxy = reqt.proxy
}
respt = request_cb(redirt)
-- we pass the location header as a clue we tried to redirect
if respt.headers then respt.headers.location = redirt.url end
return respt
end
-----------------------------------------------------------------------------
-- Computes the request URI from the parsed request URL
-- If we are using a proxy, we use the absoluteURI format.
-- Otherwise, we use the abs_path format.
-- Input
-- parsed: parsed URL
-- Returns
-- uri: request URI for parsed URL
-----------------------------------------------------------------------------
local function request_uri(reqt, parsed)
local function request_uri(reqt, respt)
local url
local parsed = respt.tmp.parsed
if not reqt.proxy then
url = {
path = parsed.path,
@ -487,219 +204,187 @@ local function request_uri(reqt, parsed)
query = parsed.query,
fragment = parsed.fragment
}
else url = parsed end
else url = respt.tmp.parsed end
return socket.url.build(url)
end
-----------------------------------------------------------------------------
-- Builds a request table from a URL or request table
-- Input
-- url_or_request: target url or request table (a table with the fields:
-- url: the target URL
-- user: account user name
-- password: account password)
-- Returns
-- reqt: request table
-----------------------------------------------------------------------------
local function build_request(data)
local reqt = {}
if type(data) == "table" then
for i, v in data
do reqt[i] = v
end
else reqt.url = data end
return reqt
local function send_request(reqt, respt)
local uri = request_uri(reqt, respt)
local sock = respt.tmp.sock
local headers = respt.tmp.headers
-- send request line
socket.try(sock:send((reqt.method or "GET")
.. " " .. uri .. " HTTP/1.1\r\n"))
-- send request headers headeres
if reqt.source and not headers["content-length"] then
headers["transfer-encoding"] = "chunked"
end
send_headers(sock, headers)
-- send request message body, if any
if reqt.source then
if headers["content-length"] then send_body(sock, reqt.source)
else send_body_bychunks(sock, reqt.source) end
end
end
-----------------------------------------------------------------------------
-- Connects to a server, be it a proxy or not
-- Input
-- reqt: the request table
-- parsed: the parsed request url
-- Returns
-- sock: connection socket, or nil in case of error
-- err: error message
-----------------------------------------------------------------------------
local function try_connect(reqt, parsed)
reqt.proxy = reqt.proxy or PROXY
local function open(reqt, respt)
local parsed = respt.tmp.parsed
local proxy = reqt.proxy or PROXY
local host, port
if reqt.proxy then
local pproxy = socket.url.parse(reqt.proxy)
if not pproxy.port or not pproxy.host then
return nil, "invalid proxy"
end
if proxy then
local pproxy = socket.url.parse(proxy)
assert(pproxy.port and pproxy.host, "invalid proxy")
host, port = pproxy.host, pproxy.port
else
host, port = parsed.host, parsed.port
end
local sock, ret, err
sock, err = socket.tcp()
if not sock then return nil, err end
local sock = socket.try(socket.tcp())
-- store results
respt.tmp.sock = sock
sock:settimeout(reqt.timeout or TIMEOUT)
ret, err = sock:connect(host, port)
if not ret then
sock:close()
return nil, err
end
return sock
socket.try(sock:connect(host, port))
end
-----------------------------------------------------------------------------
-- Sends a HTTP request and retrieves the server reply using callbacks to
-- send the request body and receive the response body
-- Input
-- reqt: a table with the following fields
-- method: "GET", "PUT", "POST" etc (defaults to "GET")
-- url: target uniform resource locator
-- user, password: authentication information
-- headers: request headers to send, or nil if none
-- source: request message body source, or nil if none
-- sink: response message body sink
-- redirect: should we refrain from following a server redirect message?
-- Returns
-- respt: a table with the following fields:
-- headers: response header fields received, or nil if failed
-- status: server response status line, or nil if failed
-- code: server status code, or nil if failed
-- error: error message, or nil if successfull
-----------------------------------------------------------------------------
function request_cb(reqt)
local sock, ret
function adjust_headers(reqt, respt)
local lower = {}
local headers = reqt.headers or {}
-- set default headers
lower["user-agent"] = USERAGENT
-- override with user values
for i,v in headers do
lower[string.lower(i)] = v
end
lower["host"] = respt.tmp.parsed.host
-- this cannot be overriden
lower["connection"] = "close"
-- store results
respt.tmp.headers = lower
end
function parse_url(reqt, respt)
-- parse url with default fields
local parsed = socket.url.parse(reqt.url, {
host = "",
port = PORT,
path ="/",
scheme = "http"
})
local respt = {}
-- scheme has to be http
if parsed.scheme ~= "http" then
respt.error = string.format("unknown scheme '%s'", parsed.scheme)
return respt
end
error(string.format("unknown scheme '%s'", parsed.scheme))
end
-- explicit authentication info overrides that given by the URL
parsed.user = reqt.user or parsed.user
parsed.password = reqt.password or parsed.password
-- default method
reqt.method = reqt.method or "GET"
-- fill default headers
reqt.headers = fill_headers(reqt.headers, parsed)
-- try to connect to server
sock, respt.error = try_connect(reqt, parsed)
if not sock then return respt end
-- send request message
respt.error = send_request(sock, reqt.method,
request_uri(reqt, parsed), reqt.headers, reqt.source)
if respt.error then
sock:close()
return respt
-- store results
respt.tmp.parsed = parsed
end
local function should_authorize(reqt, respt)
-- if there has been an authorization attempt, it must have failed
if reqt.headers and reqt.headers["authorization"] then return nil end
-- if we don't have authorization information, we can't retry
return respt.tmp.parsed.user and respt.tmp.parsed.password
end
local function clone(headers)
if not headers then return nil end
local copy = {}
for i,v in pairs(headers) do
copy[i] = v
end
-- get server response message
respt.code, respt.status, respt.error = receive_status(sock)
if respt.error then return respt end
-- deal with continue 100
-- servers should not send them, but some do!
if respt.code == 100 then
respt.headers, respt.error = receive_headers(sock, {})
if respt.error then return respt end
respt.code, respt.status, respt.error = receive_status(sock)
if respt.error then return respt end
end
-- receive all headers
respt.headers, respt.error = receive_headers(sock, {})
if respt.error then return respt end
-- decide what to do based on request and response parameters
return copy
end
local function authorize(reqt, respt)
local headers = clone(reqt.headers) or {}
local parsed = respt.tmp.parsed
headers["authorization"] = "Basic " ..
(mime.b64(parsed.user .. ":" .. parsed.password))
local autht = {
method = reqt.method,
url = reqt.url,
source = reqt.source,
sink = reqt.sink,
headers = headers,
timeout = reqt.timeout,
proxy = reqt.proxy,
}
request_p(autht, respt)
end
local function should_redirect(reqt, respt)
return (reqt.redirect ~= false) and
(respt.code == 301 or respt.code == 302) and
(not reqt.method or reqt.method == "GET" or reqt.method == "HEAD")
and (not respt.tmp.nredirects or respt.tmp.nredirects < 5)
end
local function redirect(reqt, respt)
respt.tmp.nredirects = (respt.tmp.nredirects or 0) + 1
local redirt = {
method = reqt.method,
-- the RFC says the redirect URL has to be absolute, but some
-- servers do not respect that
url = socket.url.absolute(reqt.url, respt.headers["location"]),
source = reqt.source,
sink = reqt.sink,
headers = reqt.headers,
timeout = reqt.timeout,
proxy = reqt.proxy
}
request_p(redirt, respt)
-- we pass the location header as a clue we redirected
if respt.headers then respt.headers.location = redirt.url end
end
function request_p(reqt, respt)
parse_url(reqt, respt)
adjust_headers(reqt, respt)
open(reqt, respt)
send_request(reqt, respt)
receive_status(reqt, respt)
respt.headers = {}
receive_headers(respt.tmp.sock, respt.headers)
if should_redirect(reqt, respt) then
-- drop the body
receive_body(sock, respt.headers, ltn12.sink.null())
-- we are done with this connection
sock:close()
return redirect(reqt, respt)
elseif should_authorize(reqt, parsed, respt) then
-- drop the body
receive_body(sock, respt.headers, ltn12.sink.null())
-- we are done with this connection
sock:close()
return authorize(reqt, parsed, respt)
respt.tmp.sock:close()
redirect(reqt, respt)
elseif should_authorize(reqt, respt) then
respt.tmp.sock:close()
authorize(reqt, respt)
elseif should_receive_body(reqt, respt) then
respt.error = receive_body(sock, respt.headers, reqt.sink)
if respt.error then return respt end
sock:close()
return respt
receive_body(reqt, respt)
end
sock:close()
return respt
end
-----------------------------------------------------------------------------
-- Sends a HTTP request and retrieves the server reply
-- Input
-- reqt: a table with the following fields
-- method: "GET", "PUT", "POST" etc (defaults to "GET")
-- url: request URL, i.e. the document to be retrieved
-- user, password: authentication information
-- headers: request header fields, or nil if none
-- body: request message body as a string, or nil if none
-- redirect: should we refrain from following a server redirect message?
-- Returns
-- respt: a table with the following fields:
-- body: response message body, or nil if failed
-- headers: response header fields, or nil if failed
-- status: server response status line, or nil if failed
-- code: server response status code, or nil if failed
-- error: error message if any
-----------------------------------------------------------------------------
function request(reqt)
reqt.source = reqt.body and ltn12.source.string(reqt.body)
local t = {}
reqt.sink = ltn12.sink.table(t)
local respt = request_cb(reqt)
if table.getn(t) > 0 then respt.body = table.concat(t) end
local respt = { tmp = {} }
local s, e = pcall(request_p, reqt, respt)
if not s then respt.error = e end
if respt.tmp.sock then respt.tmp.sock:close() end
respt.tmp = nil
return respt
end
-----------------------------------------------------------------------------
-- Retrieves a URL by the method "GET"
-- Input
-- url_or_request: target url or request table (a table with the fields:
-- url: the target URL
-- user: account user name
-- password: account password)
-- Returns
-- body: response message body, or nil if failed
-- headers: response header fields received, or nil if failed
-- code: server response status code, or nil if failed
-- error: error message if any
-----------------------------------------------------------------------------
function get(url_or_request)
local reqt = build_request(url_or_request)
reqt.method = "GET"
local respt = request(reqt)
return respt.body, respt.headers, respt.code, respt.error
function get(url)
local t = {}
respt = request {
url = url,
sink = ltn12.sink.table(t)
}
return table.getn(t) > 0 and table.concat(t), respt.headers,
respt.code, respt.error
end
-----------------------------------------------------------------------------
-- Retrieves a URL by the method "POST"
-- Input
-- url_or_request: target url or request table (a table with the fields:
-- url: the target URL
-- body: request message body
-- user: account user name
-- password: account password)
-- body: request message body, or nil if none
-- Returns
-- body: response message body, or nil if failed
-- headers: response header fields received, or nil if failed
-- code: server response status code, or nil if failed
-- error: error message, or nil if successfull
-----------------------------------------------------------------------------
function post(url_or_request, body)
local reqt = build_request(url_or_request)
reqt.method = "POST"
reqt.body = reqt.body or body
reqt.headers = reqt.headers or
{ ["content-length"] = string.len(reqt.body) }
local respt = request(reqt)
return respt.body, respt.headers, respt.code, respt.error
function post(url, body)
local t = {}
respt = request {
url = url,
method = "POST",
source = ltn12.source.string(body),
sink = ltn12.sink.table(t),
headers = { ["content-length"] = string.len(body) }
}
return table.getn(t) > 0 and table.concat(t),
respt.headers, respt.code, respt.error
end
return socket.http

View File

@ -171,9 +171,8 @@ function sink.file(handle, io_err)
return function(chunk, err)
if not chunk then
handle:close()
return nil, err
end
return handle:write(chunk)
return 1
else return handle:write(chunk) end
end
else return sink.error(io_err or "unable to open file") end
end

View File

@ -619,28 +619,27 @@ static int mime_global_qpwrp(lua_State *L)
* end of line markers each, but \r\n, \n\r etc will only issue *one*
* marker. This covers Mac OS, Mac OS X, VMS, Unix and DOS, as well as
* probably other more obscure conventions.
*
* c is the current character being processed
* last is the previous character
\*-------------------------------------------------------------------------*/
#define eolcandidate(c) (c == CR || c == LF)
static size_t eolprocess(int c, int ctx, const char *marker,
static int eolprocess(int c, int last, const char *marker,
luaL_Buffer *buffer)
{
if (eolcandidate(ctx)) {
luaL_addstring(buffer, marker);
if (eolcandidate(c)) {
if (c == ctx)
luaL_addstring(buffer, marker);
if (eolcandidate(c)) {
if (eolcandidate(last)) {
if (c == last) luaL_addstring(buffer, marker);
return 0;
} else {
luaL_putchar(buffer, c);
return 0;
luaL_addstring(buffer, marker);
return c;
}
} else {
if (!eolcandidate(c)) {
luaL_putchar(buffer, c);
return 0;
} else
return c;
luaL_putchar(buffer, c);
return 0;
}
}
/*-------------------------------------------------------------------------*\
@ -661,8 +660,7 @@ static int mime_global_eol(lua_State *L)
luaL_buffinit(L, &buffer);
/* if the last character was a candidate, we output a new line */
if (!input) {
if (eolcandidate(ctx)) lua_pushstring(L, marker);
else lua_pushnil(L);
lua_pushnil(L);
lua_pushnumber(L, 0);
return 2;
}

View File

@ -8,7 +8,7 @@ dofile("noglobals.lua")
local host, proxy, request, response, index_file
local ignore, expect, index, prefix, cgiprefix, index_crlf
socket.http.TIMEOUT = 5
socket.http.TIMEOUT = 10
local t = socket.time()
@ -49,7 +49,9 @@ local check_result = function(response, expect, ignore)
for i,v in response do
if not ignore[i] then
if v ~= expect[i] then
print(string.sub(tostring(v), 1, 70))
local f = io.open("err", "w")
f:write(tostring(v), "\n\n versus\n\n", tostring(expect[i]))
f:close()
fail(i .. " differs!")
end
end
@ -57,8 +59,10 @@ local check_result = function(response, expect, ignore)
for i,v in expect do
if not ignore[i] then
if v ~= response[i] then
local f = io.open("err", "w")
f:write(tostring(response[i]), "\n\n versus\n\n", tostring(v))
v = string.sub(type(v) == "string" and v or "", 1, 70)
print(string.sub(tostring(v), 1, 70))
f:close()
fail(i .. " differs!")
end
end
@ -67,12 +71,14 @@ local check_result = function(response, expect, ignore)
end
local check_request = function(request, expect, ignore)
local t
if not request.sink then
request.sink, t = ltn12.sink.table(t)
end
request.source = request.source or
(request.body and ltn12.source.string(request.body))
local response = socket.http.request(request)
check_result(response, expect, ignore)
end
local check_request_cb = function(request, expect, ignore)
local response = socket.http.request_cb(request)
if t and table.getn(t) > 0 then response.body = table.concat(t) end
check_result(response, expect, ignore)
end
@ -183,7 +189,7 @@ ignore = {
status = 1,
headers = 1
}
check_request_cb(request, expect, ignore)
check_request(request, expect, ignore)
back = readfile(index_file .. "-back")
check(back == index)
os.remove(index_file .. "-back")
@ -225,19 +231,11 @@ ignore = {
status = 1,
headers = 1
}
check_request_cb(request, expect, ignore)
check_request(request, expect, ignore)
back = readfile(index_file .. "-back")
check(back == index)
os.remove(index_file .. "-back")
------------------------------------------------------------------------
io.write("testing simple post function with table args: ")
back = socket.http.post {
url = "http://" .. host .. cgiprefix .. "/cat",
body = index
}
check(back == index)
------------------------------------------------------------------------
io.write("testing http redirection: ")
request = {
@ -438,15 +436,6 @@ io.write("testing simple get function: ")
body = socket.http.get("http://" .. host .. prefix .. "/index.html")
check(body == index)
------------------------------------------------------------------------
io.write("testing simple get function with table args: ")
body = socket.http.get {
url = "http://really:wrong@" .. host .. prefix .. "/auth/index.html",
user = "luasocket",
password = "password"
}
check(body == index)
------------------------------------------------------------------------
io.write("testing HEAD method: ")
socket.http.TIMEOUT = 1

View File

@ -17,14 +17,12 @@ function warn(...)
io.stderr:write("WARNING: ", s, "\n")
end
pad = string.rep(" ", 8192)
function remote(...)
local s = string.format(unpack(arg))
s = string.gsub(s, "\n", ";")
s = string.gsub(s, "%s+", " ")
s = string.gsub(s, "^%s*", "")
control:send(pad, s, "\n")
control:send(s, "\n")
control:receive()
end
@ -122,7 +120,13 @@ remote (string.format("str = data:receive(%d)",
sent, err = data:send(p1, p2, p3, p4)
if err then fail(err) end
remote "data:send(str); data:close()"
bp1, bp2, bp3, bp4, err = data:receive("*l", "*l", string.len(p3), "*a")
bp1, err = data:receive()
if err then fail(err) end
bp2, err = data:receive()
if err then fail(err) end
bp3, err = data:receive(string.len(p3))
if err then fail(err) end
bp4, err = data:receive("*a")
if err then fail(err) end
if bp1.."\n" == p1 and bp2.."\r\n" == p2 and bp3 == p3 and bp4 == p4 then
pass("patterns match")
@ -186,7 +190,7 @@ end
------------------------------------------------------------------------
function test_totaltimeoutreceive(len, tm, sl)
reconnect()
local str, err, total
local str, err, partial
pass("%d bytes, %ds total timeout, %ds pause", len, tm, sl)
remote (string.format ([[
data:settimeout(%d)
@ -198,9 +202,9 @@ function test_totaltimeoutreceive(len, tm, sl)
data:send(str)
]], 2*tm, len, sl, sl))
data:settimeout(tm, "total")
str, err, elapsed = data:receive(2*len)
str, err, partial, elapsed = data:receive(2*len)
check_timeout(tm, sl, elapsed, err, "receive", "total",
string.len(str) == 2*len)
string.len(str or partial) == 2*len)
end
------------------------------------------------------------------------
@ -226,7 +230,7 @@ end
------------------------------------------------------------------------
function test_blockingtimeoutreceive(len, tm, sl)
reconnect()
local str, err, total
local str, err, partial
pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl)
remote (string.format ([[
data:settimeout(%d)
@ -238,9 +242,9 @@ function test_blockingtimeoutreceive(len, tm, sl)
data:send(str)
]], 2*tm, len, sl, sl))
data:settimeout(tm)
str, err, elapsed = data:receive(2*len)
str, err, partial, elapsed = data:receive(2*len)
check_timeout(tm, sl, elapsed, err, "receive", "blocking",
string.len(str) == 2*len)
string.len(str or partial) == 2*len)
end
------------------------------------------------------------------------
@ -298,7 +302,7 @@ end
------------------------------------------------------------------------
function test_closed()
local back, err
local back, partial, err
local str = 'little string'
reconnect()
pass("trying read detection")
@ -308,10 +312,10 @@ function test_closed()
data = nil
]], str))
-- try to get a line
back, err = data:receive()
if not err then fail("shold have gotten 'closed'.")
back, err, partial = data:receive()
if not err then fail("should have gotten 'closed'.")
elseif err ~= "closed" then fail("got '"..err.."' instead of 'closed'.")
elseif str ~= back then fail("didn't receive partial result.")
elseif str ~= partial then fail("didn't receive partial result.")
else pass("graceful 'closed' received") end
reconnect()
pass("trying write detection")
@ -456,7 +460,6 @@ test_methods(socket.udp(), {
"setpeername",
"setsockname",
"settimeout",
"shutdown",
})
test("select function")
@ -481,6 +484,7 @@ accept_timeout()
accept_errors()
test("mixed patterns")
test_mixed(1)
test_mixed(17)

View File

@ -6,7 +6,7 @@ mesgt = {
body = {
preamble = "Some attatched stuff",
[1] = {
body = "Testing stuffing.\r\n.\r\nGot you.\r\n.Hehehe.\r\n"
body = mime.eol(0, "Testing stuffing.\n.\nGot you.\n.Hehehe.\n")
},
[2] = {
headers = {
@ -29,7 +29,7 @@ mesgt = {
["content-transfer-encoding"] = "QUOTED-PRINTABLE"
},
body = ltn12.source.chain(
ltn12.source.file(io.open("message.lua", "rb")),
ltn12.source.file(io.open("testmesg.lua", "rb")),
ltn12.filter.chain(
mime.normalize(),
mime.encode("quoted-printable"),
@ -46,8 +46,8 @@ mesgt = {
-- ltn12.pump(source, sink)
print(socket.smtp.send {
rcpt = {"<db@werx4.com>", "<diego@cs.princeton.edu>"},
rcpt = "<diego@cs.princeton.edu>",
from = "<diego@cs.princeton.edu>",
source = socket.smtp.message(mesgt),
server = "smtp.princeton.edu"
server = "mail.cs.princeton.edu"
})

View File

@ -22,6 +22,8 @@ while 1 do
print("server: closing connection...")
break
end
print(command);
(loadstring(command))()
end
end