From 7006ae120da2f8ec74323422541cd585c91832d6 Mon Sep 17 00:00:00 2001 From: Philipp Janda Date: Mon, 10 Nov 2014 18:17:10 +0100 Subject: [PATCH] fixed yieldable socket.protect in etc/dispatch.lua --- etc/dispatch.lua | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/etc/dispatch.lua b/etc/dispatch.lua index cab7f59..2485415 100644 --- a/etc/dispatch.lua +++ b/etc/dispatch.lua @@ -5,6 +5,7 @@ ----------------------------------------------------------------------------- local base = _G local table = require("table") +local string = require("string") local socket = require("socket") local coroutine = require("coroutine") module("dispatch") @@ -43,26 +44,32 @@ end ----------------------------------------------------------------------------- -- Mega hack. Don't try to do this at home. ----------------------------------------------------------------------------- --- we can't yield across calls to protect, so we rewrite it with coxpcall +-- we can't yield across calls to protect on Lua 5.1, so we rewrite it with +-- coroutines -- make sure you don't require any module that uses socket.protect before -- loading our hack -function socket.protect(f) - return function(...) - local co = coroutine.create(f) - while true do - local results = {coroutine.resume(co, ...)} - local status = table.remove(results, 1) - if not status then - if base.type(results[1]) == 'table' then - return nil, results[1][1] - else base.error(results[1]) end - end - if coroutine.status(co) == "suspended" then - arg = {coroutine.yield(base.unpack(results))} +if string.sub(base._VERSION, -3) == "5.1" then + local function _protect(co, status, ...) + if not status then + local msg = ... + if base.type(msg) == 'table' then + return nil, msg[1] else - return base.unpack(results) + base.error(msg, 0) end end + if coroutine.status(co) == "suspended" then + return _protect(co, coroutine.resume(co, coroutine.yield(...))) + else + return ... + end + end + + function socket.protect(f) + return function(...) + local co = coroutine.create(f) + return _protect(co, coroutine.resume(co, ...)) + end end end