Merge pull request #113 from siffiejoe/yieldable_protect51

fixed yieldable socket.protect in etc/dispatch.lua
This commit is contained in:
Diego Nehab 2014-11-10 15:39:34 -02:00
commit 583257c28c

View File

@ -5,6 +5,7 @@
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
local base = _G local base = _G
local table = require("table") local table = require("table")
local string = require("string")
local socket = require("socket") local socket = require("socket")
local coroutine = require("coroutine") local coroutine = require("coroutine")
module("dispatch") module("dispatch")
@ -43,26 +44,32 @@ end
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
-- Mega hack. Don't try to do this at home. -- Mega hack. Don't try to do this at home.
----------------------------------------------------------------------------- -----------------------------------------------------------------------------
-- we can't yield across calls to protect, so we rewrite it with coxpcall -- we can't yield across calls to protect on Lua 5.1, so we rewrite it with
-- coroutines
-- make sure you don't require any module that uses socket.protect before -- make sure you don't require any module that uses socket.protect before
-- loading our hack -- loading our hack
function socket.protect(f) if string.sub(base._VERSION, -3) == "5.1" then
return function(...) local function _protect(co, status, ...)
local co = coroutine.create(f) if not status then
while true do local msg = ...
local results = {coroutine.resume(co, ...)} if base.type(msg) == 'table' then
local status = table.remove(results, 1) return nil, msg[1]
if not status then
if base.type(results[1]) == 'table' then
return nil, results[1][1]
else base.error(results[1]) end
end
if coroutine.status(co) == "suspended" then
arg = {coroutine.yield(base.unpack(results))}
else else
return base.unpack(results) base.error(msg, 0)
end end
end end
if coroutine.status(co) == "suspended" then
return _protect(co, coroutine.resume(co, coroutine.yield(...)))
else
return ...
end
end
function socket.protect(f)
return function(...)
local co = coroutine.create(f)
return _protect(co, coroutine.resume(co, ...))
end
end end
end end