diff --git a/etc/dispatch.lua b/etc/dispatch.lua index cab7f59..2485415 100644 --- a/etc/dispatch.lua +++ b/etc/dispatch.lua @@ -5,6 +5,7 @@ ----------------------------------------------------------------------------- local base = _G local table = require("table") +local string = require("string") local socket = require("socket") local coroutine = require("coroutine") module("dispatch") @@ -43,26 +44,32 @@ end ----------------------------------------------------------------------------- -- Mega hack. Don't try to do this at home. ----------------------------------------------------------------------------- --- we can't yield across calls to protect, so we rewrite it with coxpcall +-- we can't yield across calls to protect on Lua 5.1, so we rewrite it with +-- coroutines -- make sure you don't require any module that uses socket.protect before -- loading our hack -function socket.protect(f) - return function(...) - local co = coroutine.create(f) - while true do - local results = {coroutine.resume(co, ...)} - local status = table.remove(results, 1) - if not status then - if base.type(results[1]) == 'table' then - return nil, results[1][1] - else base.error(results[1]) end - end - if coroutine.status(co) == "suspended" then - arg = {coroutine.yield(base.unpack(results))} +if string.sub(base._VERSION, -3) == "5.1" then + local function _protect(co, status, ...) + if not status then + local msg = ... + if base.type(msg) == 'table' then + return nil, msg[1] else - return base.unpack(results) + base.error(msg, 0) end end + if coroutine.status(co) == "suspended" then + return _protect(co, coroutine.resume(co, coroutine.yield(...))) + else + return ... + end + end + + function socket.protect(f) + return function(...) + local co = coroutine.create(f) + return _protect(co, coroutine.resume(co, ...)) + end end end diff --git a/src/except.c b/src/except.c index 002e701..4faa208 100644 --- a/src/except.c +++ b/src/except.c @@ -9,6 +9,15 @@ #include "except.h" +#if LUA_VERSION_NUM < 502 +#define lua_pcallk(L, na, nr, err, ctx, cont) \ + ((void)ctx,(void)cont,lua_pcall(L, na, nr, err)) +#endif + +#if LUA_VERSION_NUM < 503 +typedef int lua_KContext; +#endif + /*=========================================================================*\ * Internal function prototypes. \*=========================================================================*/ @@ -73,14 +82,30 @@ static int unwrap(lua_State *L) { } else return 0; } +static int protected_finish(lua_State *L, int status, lua_KContext ctx) { + (void)ctx; + if (status != 0 && status != LUA_YIELD) { + if (unwrap(L)) return 2; + else return lua_error(L); + } else return lua_gettop(L); +} + +#if LUA_VERSION_NUM == 502 +static int protected_cont(lua_State *L) { + int ctx = 0; + int status = lua_getctx(L, &ctx); + return protected_finish(L, status, ctx); +} +#else +#define protected_cont protected_finish +#endif + static int protected_(lua_State *L) { + int status; lua_pushvalue(L, lua_upvalueindex(1)); lua_insert(L, 1); - if (lua_pcall(L, lua_gettop(L) - 1, LUA_MULTRET, 0) != 0) { - if (unwrap(L)) return 2; - else lua_error(L); - return 0; - } else return lua_gettop(L); + status = lua_pcallk(L, lua_gettop(L) - 1, LUA_MULTRET, 0, 0, protected_cont); + return protected_finish(L, status, 0); } static int global_protect(lua_State *L) { diff --git a/src/http.lua b/src/http.lua index 1d0eb50..d5457f6 100644 --- a/src/http.lua +++ b/src/http.lua @@ -209,8 +209,7 @@ end local function adjustheaders(reqt) -- default headers - local host = reqt.host - if reqt.port then host = host .. ":" .. reqt.port end + local host = string.gsub(reqt.authority, "^.-@", "") local lower = { ["user-agent"] = _M.USERAGENT, ["host"] = host, @@ -353,4 +352,4 @@ _M.request = socket.protect(function(reqt, body) else return trequest(reqt) end end) -return _M \ No newline at end of file +return _M diff --git a/src/usocket.c b/src/usocket.c index da09130..4fe333e 100644 --- a/src/usocket.c +++ b/src/usocket.c @@ -82,7 +82,6 @@ int socket_close(void) { \*-------------------------------------------------------------------------*/ void socket_destroy(p_socket ps) { if (*ps != SOCKET_INVALID) { - socket_setblocking(ps); close(*ps); *ps = SOCKET_INVALID; } @@ -130,9 +129,7 @@ int socket_bind(p_socket ps, SA *addr, socklen_t len) { \*-------------------------------------------------------------------------*/ int socket_listen(p_socket ps, int backlog) { int err = IO_DONE; - socket_setblocking(ps); if (listen(*ps, backlog)) err = errno; - socket_setnonblocking(ps); return err; } @@ -140,9 +137,7 @@ int socket_listen(p_socket ps, int backlog) { * \*-------------------------------------------------------------------------*/ void socket_shutdown(p_socket ps, int how) { - socket_setblocking(ps); shutdown(*ps, how); - socket_setnonblocking(ps); } /*-------------------------------------------------------------------------*\