diff --git a/doc/socket.html b/doc/socket.html index 8a81414..a43a208 100644 --- a/doc/socket.html +++ b/doc/socket.html @@ -220,13 +220,6 @@ Returns an equivalent function that instead of throwing exceptions, returns nil followed by an error message.
--Note: Beware that if your function performs some illegal operation that -raises an error, the protected function will catch the error and return it -as a string. This is because the try function -uses errors as the mechanism to throw exceptions. -
-@@ -424,8 +417,7 @@ socket.try(ret1 [, ret2 ... retN]) Throws an exception in case of error. The exception can only be caught -by the protect function. It does not explode -into an error message. +by the protect function.
@@ -436,7 +428,10 @@ nested with try.
The function returns ret1 to retN if -ret1 is not nil. Otherwise, it calls error passing ret2. +ret1 is not nil or false. +Otherwise, it calls error passing ret2 wrapped +in a table with metatable used by protect to +distinguish exceptions from runtime errors.
diff --git a/src/except.c b/src/except.c index 261ac98..def35a0 100644 --- a/src/except.c +++ b/src/except.c @@ -12,7 +12,7 @@ #if LUA_VERSION_NUM < 502 #define lua_pcallk(L, na, nr, err, ctx, cont) \ - ((void)ctx,(void)cont,lua_pcall(L, na, nr, err)) + (((void)ctx),((void)cont),lua_pcall(L, na, nr, err)) #endif #if LUA_VERSION_NUM < 503 @@ -39,12 +39,11 @@ static luaL_Reg func[] = { * Try factory \*-------------------------------------------------------------------------*/ static void wrap(lua_State *L) { - lua_newtable(L); - lua_pushnumber(L, 1); - lua_pushvalue(L, -3); - lua_settable(L, -3); - lua_insert(L, -2); - lua_pop(L, 1); + lua_createtable(L, 1, 0); + lua_pushvalue(L, -2); + lua_rawseti(L, -2, 1); + lua_pushvalue(L, lua_upvalueindex(2)); + lua_setmetatable(L, -2); } static int finalize(lua_State *L) { @@ -58,15 +57,16 @@ static int finalize(lua_State *L) { } else return lua_gettop(L); } -static int do_nothing(lua_State *L) { +static int do_nothing(lua_State *L) { (void) L; - return 0; + return 0; } static int global_newtry(lua_State *L) { lua_settop(L, 1); if (lua_isnil(L, 1)) lua_pushcfunction(L, do_nothing); - lua_pushcclosure(L, finalize, 1); + lua_pushvalue(L, lua_upvalueindex(1)); + lua_pushcclosure(L, finalize, 2); return 1; } @@ -74,13 +74,16 @@ static int global_newtry(lua_State *L) { * Protect factory \*-------------------------------------------------------------------------*/ static int unwrap(lua_State *L) { - if (lua_istable(L, -1)) { - lua_pushnumber(L, 1); - lua_gettable(L, -2); - lua_pushnil(L); - lua_insert(L, -2); - return 1; - } else return 0; + if (lua_istable(L, -1) && lua_getmetatable(L, -1)) { + int r = lua_rawequal(L, -1, lua_upvalueindex(2)); + lua_pop(L, 1); + if (r) { + lua_pushnil(L); + lua_rawgeti(L, -2, 1); + return 1; + } + } + return 0; } static int protected_finish(lua_State *L, int status, lua_KContext ctx) { @@ -110,7 +113,9 @@ static int protected_(lua_State *L) { } static int global_protect(lua_State *L) { - lua_pushcclosure(L, protected_, 1); + lua_settop(L, 1); + lua_pushvalue(L, lua_upvalueindex(1)); + lua_pushcclosure(L, protected_, 2); return 1; } @@ -118,6 +123,9 @@ static int global_protect(lua_State *L) { * Init module \*-------------------------------------------------------------------------*/ int except_open(lua_State *L) { - luaL_setfuncs(L, func, 0); + lua_newtable(L); /* metatable for wrapped exceptions */ + lua_pushboolean(L, 0); + lua_setfield(L, -2, "__metatable"); + luaL_setfuncs(L, func, 1); return 0; } diff --git a/src/except.h b/src/except.h index 1e7a245..2497c05 100644 --- a/src/except.h +++ b/src/except.h @@ -9,21 +9,26 @@ * error checking was taking a substantial amount of the coding. These * function greatly simplify the task of checking errors. * -* The main idea is that functions should return nil as its first return -* value when it finds an error, and return an error message (or value) +* The main idea is that functions should return nil as their first return +* values when they find an error, and return an error message (or value) * following nil. In case of success, as long as the first value is not nil, * the other values don't matter. * * The idea is to nest function calls with the "try" function. This function -* checks the first value, and calls "error" on the second if the first is -* nil. Otherwise, it returns all values it received. +* checks the first value, and, if it's falsy, wraps the second value in a +* table with metatable and calls "error" on it. Otherwise, it returns all +* values it received. Basically, it works like the Lua "assert" function, +* but it creates errors targeted specifically at "protect". * -* The protect function returns a new function that behaves exactly like the -* function it receives, but the new function doesn't throw exceptions: it -* returns nil followed by the error message instead. +* The "newtry" function is a factory for "try" functions that call a +* finalizer in protected mode before calling "error". * -* With these two function, it's easy to write functions that throw -* exceptions on error, but that don't interrupt the user script. +* The "protect" function returns a new function that behaves exactly like +* the function it receives, but the new function catches exceptions thrown +* by "try" functions and returns nil followed by the error message instead. +* +* With these three functions, it's easy to write functions that throw +* exceptions on error, but that don't interrupt the user script. \*=========================================================================*/ #include "lua.h" diff --git a/test/excepttest.lua b/test/excepttest.lua index ce9f197..6904545 100644 --- a/test/excepttest.lua +++ b/test/excepttest.lua @@ -1,6 +1,31 @@ local socket = require("socket") -try = socket.newtry(function() - print("finalized!!!") + +local finalizer_called + +local func = socket.protect(function(err, ...) + local try = socket.newtry(function() + finalizer_called = true + error("ignored") + end) + + if err then + return error(err, 0) + else + return try(...) + end end) -try = socket.protect(try) -print(try(nil, "it works")) + +local ret1, ret2, ret3 = func(false, 1, 2, 3) +assert(not finalizer_called, "unexpected finalizer call") +assert(ret1 == 1 and ret2 == 2 and ret3 == 3, "incorrect return values") + +ret1, ret2, ret3 = func(false, false, "error message") +assert(finalizer_called, "finalizer not called") +assert(ret1 == nil and ret2 == "error message" and ret3 == nil, "incorrect return values") + +local err = {key = "value"} +ret1, ret2 = pcall(func, err) +assert(not ret1, "error not rethrown") +assert(ret2 == err, "incorrect error rethrown") + +print("OK")