Merge remote-tracking branch 'upstream/master'

This commit is contained in:
Bruno Silvestre 2015-02-11 09:05:48 -02:00
commit c464a9218b
4 changed files with 54 additions and 28 deletions

View File

@ -5,6 +5,7 @@
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
local base = _G local base = _G
local table = require("table") local table = require("table")
local string = require("string")
local socket = require("socket") local socket = require("socket")
local coroutine = require("coroutine") local coroutine = require("coroutine")
module("dispatch") module("dispatch")
@ -43,25 +44,31 @@ end
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
-- Mega hack. Don't try to do this at home. -- Mega hack. Don't try to do this at home.
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
-- we can't yield across calls to protect, so we rewrite it with coxpcall -- we can't yield across calls to protect on Lua 5.1, so we rewrite it with
-- coroutines
-- make sure you don't require any module that uses socket.protect before -- make sure you don't require any module that uses socket.protect before
-- loading our hack -- loading our hack
if string.sub(base._VERSION, -3) == "5.1" then
local function _protect(co, status, ...)
if not status then
local msg = ...
if base.type(msg) == 'table' then
return nil, msg[1]
else
base.error(msg, 0)
end
end
if coroutine.status(co) == "suspended" then
return _protect(co, coroutine.resume(co, coroutine.yield(...)))
else
return ...
end
end
function socket.protect(f) function socket.protect(f)
return function(...) return function(...)
local co = coroutine.create(f) local co = coroutine.create(f)
while true do return _protect(co, coroutine.resume(co, ...))
local results = {coroutine.resume(co, ...)}
local status = table.remove(results, 1)
if not status then
if base.type(results[1]) == 'table' then
return nil, results[1][1]
else base.error(results[1]) end
end
if coroutine.status(co) == "suspended" then
arg = {coroutine.yield(base.unpack(results))}
else
return base.unpack(results)
end
end end
end end
end end

View File

@ -9,6 +9,15 @@
#include "except.h" #include "except.h"
#if LUA_VERSION_NUM < 502
#define lua_pcallk(L, na, nr, err, ctx, cont) \
((void)ctx,(void)cont,lua_pcall(L, na, nr, err))
#endif
#if LUA_VERSION_NUM < 503
typedef int lua_KContext;
#endif
/*=========================================================================*\ /*=========================================================================*\
* Internal function prototypes. * Internal function prototypes.
\*=========================================================================*/ \*=========================================================================*/
@ -73,14 +82,30 @@ static int unwrap(lua_State *L) {
} else return 0; } else return 0;
} }
static int protected_finish(lua_State *L, int status, lua_KContext ctx) {
(void)ctx;
if (status != 0 && status != LUA_YIELD) {
if (unwrap(L)) return 2;
else return lua_error(L);
} else return lua_gettop(L);
}
#if LUA_VERSION_NUM == 502
static int protected_cont(lua_State *L) {
int ctx = 0;
int status = lua_getctx(L, &ctx);
return protected_finish(L, status, ctx);
}
#else
#define protected_cont protected_finish
#endif
static int protected_(lua_State *L) { static int protected_(lua_State *L) {
int status;
lua_pushvalue(L, lua_upvalueindex(1)); lua_pushvalue(L, lua_upvalueindex(1));
lua_insert(L, 1); lua_insert(L, 1);
if (lua_pcall(L, lua_gettop(L) - 1, LUA_MULTRET, 0) != 0) { status = lua_pcallk(L, lua_gettop(L) - 1, LUA_MULTRET, 0, 0, protected_cont);
if (unwrap(L)) return 2; return protected_finish(L, status, 0);
else lua_error(L);
return 0;
} else return lua_gettop(L);
} }
static int global_protect(lua_State *L) { static int global_protect(lua_State *L) {

View File

@ -209,8 +209,7 @@ end
local function adjustheaders(reqt) local function adjustheaders(reqt)
-- default headers -- default headers
local host = reqt.host local host = string.gsub(reqt.authority, "^.-@", "")
if reqt.port then host = host .. ":" .. reqt.port end
local lower = { local lower = {
["user-agent"] = _M.USERAGENT, ["user-agent"] = _M.USERAGENT,
["host"] = host, ["host"] = host,

View File

@ -82,7 +82,6 @@ int socket_close(void) {
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
void socket_destroy(p_socket ps) { void socket_destroy(p_socket ps) {
if (*ps != SOCKET_INVALID) { if (*ps != SOCKET_INVALID) {
socket_setblocking(ps);
close(*ps); close(*ps);
*ps = SOCKET_INVALID; *ps = SOCKET_INVALID;
} }
@ -130,9 +129,7 @@ int socket_bind(p_socket ps, SA *addr, socklen_t len) {
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
int socket_listen(p_socket ps, int backlog) { int socket_listen(p_socket ps, int backlog) {
int err = IO_DONE; int err = IO_DONE;
socket_setblocking(ps);
if (listen(*ps, backlog)) err = errno; if (listen(*ps, backlog)) err = errno;
socket_setnonblocking(ps);
return err; return err;
} }
@ -140,9 +137,7 @@ int socket_listen(p_socket ps, int backlog) {
* *
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
void socket_shutdown(p_socket ps, int how) { void socket_shutdown(p_socket ps, int how) {
socket_setblocking(ps);
shutdown(*ps, how); shutdown(*ps, how);
socket_setnonblocking(ps);
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\