diff --git a/src/except.c b/src/except.c index 002e701..4faa208 100644 --- a/src/except.c +++ b/src/except.c @@ -9,6 +9,15 @@ #include "except.h" +#if LUA_VERSION_NUM < 502 +#define lua_pcallk(L, na, nr, err, ctx, cont) \ + ((void)ctx,(void)cont,lua_pcall(L, na, nr, err)) +#endif + +#if LUA_VERSION_NUM < 503 +typedef int lua_KContext; +#endif + /*=========================================================================*\ * Internal function prototypes. \*=========================================================================*/ @@ -73,14 +82,30 @@ static int unwrap(lua_State *L) { } else return 0; } +static int protected_finish(lua_State *L, int status, lua_KContext ctx) { + (void)ctx; + if (status != 0 && status != LUA_YIELD) { + if (unwrap(L)) return 2; + else return lua_error(L); + } else return lua_gettop(L); +} + +#if LUA_VERSION_NUM == 502 +static int protected_cont(lua_State *L) { + int ctx = 0; + int status = lua_getctx(L, &ctx); + return protected_finish(L, status, ctx); +} +#else +#define protected_cont protected_finish +#endif + static int protected_(lua_State *L) { + int status; lua_pushvalue(L, lua_upvalueindex(1)); lua_insert(L, 1); - if (lua_pcall(L, lua_gettop(L) - 1, LUA_MULTRET, 0) != 0) { - if (unwrap(L)) return 2; - else lua_error(L); - return 0; - } else return lua_gettop(L); + status = lua_pcallk(L, lua_gettop(L) - 1, LUA_MULTRET, 0, 0, protected_cont); + return protected_finish(L, status, 0); } static int global_protect(lua_State *L) {