LTN12 bug removed.

This commit is contained in:
Diego Nehab 2004-11-28 02:36:07 +00:00
parent 05e8f24385
commit 297b32e828
4 changed files with 39 additions and 16 deletions

View File

@ -40,30 +40,27 @@ end
function filter.chain(...) function filter.chain(...)
local n = table.getn(arg) local n = table.getn(arg)
local top, index = 1, 1 local top, index = 1, 1
local retry = ""
return function(chunk) return function(chunk)
retry = chunk and retry
while true do while true do
if index == top then if index == top then
chunk = arg[index](chunk) chunk = arg[index](chunk)
if chunk == "" or top == n then if chunk == "" or top == n then return chunk
return chunk elseif chunk then index = index + 1
elseif chunk then
index = index + 1
else else
top = top+1 top = top+1
index = top index = top
end end
else else
local original = chunk chunk = arg[index](chunk or "")
chunk = arg[index](original or "")
if chunk == "" then if chunk == "" then
index = index - 1 index = index - 1
chunk = original and chunk chunk = retry
elseif chunk then elseif chunk then
if index == n then return chunk if index == n then return chunk
else index = index + 1 end else index = index + 1 end
else else base.error("filter returned inappropriate nil") end
base.error("filter returned inappropriate nil")
end
end end
end end
end end
@ -138,6 +135,8 @@ function source.rewind(src)
end end
end end
local print = print
-- chains a source with a filter -- chains a source with a filter
function source.chain(src, f) function source.chain(src, f)
base.assert(src and f) base.assert(src and f)
@ -151,13 +150,16 @@ function source.chain(src, f)
last_out = f(last_in) last_out = f(last_in)
if last_out ~= "" then return last_out end if last_out ~= "" then return last_out end
if not last_in then if not last_in then
error('filter returned inappropriate ""') base.error('filter returned inappropriate ""')
end end
end end
elseif last_out then elseif last_out then
last_out = f(last_in and "") last_out = f(last_in and "")
if last_in and not last_out then if last_in and not last_out then
error('filter returned inappropriate nil') base.error('filter returned inappropriate nil')
end
if last_out == "" and not last_in then
base.error(base.tostring(f) .. ' returned inappropriate ""')
end end
return last_out return last_out
else else

View File

@ -267,6 +267,7 @@ static int mime_global_b64(lua_State *L)
if (!input) { if (!input) {
asize = b64pad(atom, asize, &buffer); asize = b64pad(atom, asize, &buffer);
luaL_pushresult(&buffer); luaL_pushresult(&buffer);
if (!(*lua_tostring(L, -1))) lua_pushnil(L);
lua_pushnil(L); lua_pushnil(L);
return 2; return 2;
} }
@ -306,6 +307,7 @@ static int mime_global_unb64(lua_State *L)
/* if second is nil, we are done */ /* if second is nil, we are done */
if (!input) { if (!input) {
luaL_pushresult(&buffer); luaL_pushresult(&buffer);
if (!(*lua_tostring(L, -1))) lua_pushnil(L);
lua_pushnil(L); lua_pushnil(L);
return 2; return 2;
} }
@ -418,7 +420,7 @@ static size_t qppad(UC *input, size_t size, luaL_Buffer *buffer)
if (qpclass[input[i]] == QP_PLAIN) luaL_putchar(buffer, input[i]); if (qpclass[input[i]] == QP_PLAIN) luaL_putchar(buffer, input[i]);
else qpquote(input[i], buffer); else qpquote(input[i], buffer);
} }
luaL_addstring(buffer, EQCRLF); if (size > 0) luaL_addstring(buffer, EQCRLF);
return 0; return 0;
} }
@ -454,7 +456,9 @@ static int mime_global_qp(lua_State *L)
if (!input) { if (!input) {
asize = qppad(atom, asize, &buffer); asize = qppad(atom, asize, &buffer);
luaL_pushresult(&buffer); luaL_pushresult(&buffer);
if (!(*lua_tostring(L, -1))) lua_pushnil(L);
lua_pushnil(L); lua_pushnil(L);
return 2;
} }
/* otherwise process rest of input */ /* otherwise process rest of input */
last = input + isize; last = input + isize;
@ -531,6 +535,7 @@ static int mime_global_unqp(lua_State *L)
/* if second part is nil, we are done */ /* if second part is nil, we are done */
if (!input) { if (!input) {
luaL_pushresult(&buffer); luaL_pushresult(&buffer);
if (!(*lua_tostring(L, -1))) lua_pushnil(L);
lua_pushnil(L); lua_pushnil(L);
return 2; return 2;
} }

View File

@ -49,6 +49,15 @@ mime.decodet['quoted-printable'] = function()
return ltn12.filter.cycle(unqp, "") return ltn12.filter.cycle(unqp, "")
end end
local io, string = io, string
local function format(chunk)
if chunk then
if chunk == "" then return "''"
else return string.len(chunk) end
else return "nil" end
end
-- define the line-wrap filters -- define the line-wrap filters
mime.wrapt['text'] = function(length) mime.wrapt['text'] = function(length)
length = length or 76 length = length or 76

View File

@ -39,9 +39,13 @@ local mao = [[
local function random(handle, io_err) local function random(handle, io_err)
if handle then if handle then
return function() return function()
if not handle then error("source is empty!", 2) end
local len = math.random(0, 1024) local len = math.random(0, 1024)
local chunk = handle:read(len) local chunk = handle:read(len)
if not chunk then handle:close() end if not chunk then
handle:close()
handle = nil
end
return chunk return chunk
end end
else return ltn12.source.empty(io_err or "unable to open file") end else return ltn12.source.empty(io_err or "unable to open file") end
@ -73,6 +77,7 @@ local function encode_qptest(mode)
end end
local function compare_qptest() local function compare_qptest()
io.write("testing qp encoding and wrap: ")
compare(qptest, dqptest) compare(qptest, dqptest)
end end
@ -173,7 +178,6 @@ local function encode_b64test()
local sp2 = mime.wrap("base64", 30) local sp2 = mime.wrap("base64", 30)
local sp1 = mime.wrap(27) local sp1 = mime.wrap(27)
local chain = ltn12.filter.chain(e1, sp1, e2, sp2, e3, sp3, e4, sp4) local chain = ltn12.filter.chain(e1, sp1, e2, sp2, e3, sp3, e4, sp4)
chain = socket.protect(chain)
transform(b64test, eb64test, chain) transform(b64test, eb64test, chain)
end end
@ -193,10 +197,12 @@ local function cleanup_b64test()
end end
local function compare_b64test() local function compare_b64test()
io.write("testing b64 chained encode: ")
compare(b64test, db64test) compare(b64test, db64test)
end end
local function identity_test() local function identity_test()
io.write("testing identity: ")
local chain = named(ltn12.filter.chain( local chain = named(ltn12.filter.chain(
named(mime.encode("quoted-printable"), "1 eq"), named(mime.encode("quoted-printable"), "1 eq"),
named(mime.encode("base64"), "2 eb"), named(mime.encode("base64"), "2 eb"),
@ -223,11 +229,12 @@ local function chunkcheck(original, encoded)
local b = string.sub(original, i+1) local b = string.sub(original, i+1)
local e, r = mime.b64(a, b) local e, r = mime.b64(a, b)
local f = (mime.b64(r)) local f = (mime.b64(r))
if (e .. f ~= encoded) then fail(e .. f) end if (e .. (f or "") ~= encoded) then fail(e .. (f or "")) end
end end
end end
local function padding_b64test() local function padding_b64test()
io.write("testing b64 padding: ")
padcheck("a", "YQ==") padcheck("a", "YQ==")
padcheck("ab", "YWI=") padcheck("ab", "YWI=")
padcheck("abc", "YWJj") padcheck("abc", "YWJj")